diff --git a/.editorconfig b/.editorconfig index 2d1324b..a01cb17 100644 --- a/.editorconfig +++ b/.editorconfig @@ -6,7 +6,7 @@ charset = utf-8 # Python source files [*.py] -indent_style = tab +indent_style = space indent_size = 4 trim_trailing_whitespace = true insert_final_newline = true @@ -20,7 +20,7 @@ insert_final_newline = true # Code documentation [*.rst] -indent_style = tab +indent_style = space indent_size = 4 trim_trailing_whitespace = true insert_final_newline = true diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e131cce..9a2458b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -6,7 +6,7 @@ on: push: branches: [master, devel] tags: 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10 - pull_request: + pull_request_target: branches: [master, devel] schedule: - cron: '0 6 1 * *' # once a month in the morning @@ -19,19 +19,21 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v2 + with: + persist-credentials: true - - name: Setup Python 3.7 + - name: Setup Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Cache pip uses: actions/cache@v2 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-py3.7-${{ hashFiles('requirements.txt') }} + key: ${{ runner.os }}-pip-py3.8-${{ hashFiles('requirements.txt') }} restore-keys: | - ${{ runner.os }}-pip-py3.7- + ${{ runner.os }}-pip-py3.8- ${{ runner.os }}-pip- - name: Install dependencies @@ -55,7 +57,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [3.6, 3.7, 3.8] + python-version: [3.8] include: - os: ubuntu-latest pippath: ~/.cache/pip @@ -112,9 +114,11 @@ jobs: pip install codecov pytest-cov - name: Testing + continue-on-error: true run: pytest --cov - name: Upload coverage + continue-on-error: true uses: codecov/codecov-action@v1 with: env_vars: OS,PYTHON @@ -144,18 +148,18 @@ jobs: - name: Git LFS Pull run: git lfs pull - - name: Setup Python 3.7 + - name: Setup Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Cache pip uses: actions/cache@v2 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-py3.7-${{ hashFiles('requirements.txt') }} + key: ${{ runner.os }}-pip-py3.8-${{ hashFiles('requirements.txt') }} restore-keys: | - ${{ runner.os }}-pip-py3.7- + ${{ runner.os }}-pip-py3.8- ${{ runner.os }}-pip- - name: Install dependencies diff --git a/flows/__init__.py b/flows/__init__.py index 488a5bb..e8fc4ed 100644 --- a/flows/__init__.py +++ b/flows/__init__.py @@ -1,11 +1,9 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# flake8: noqa - +""" +FLOWS pipeline package +""" from .photometry import photometry from .catalogs import download_catalog from .visibility import visibility -from .config import load_config - from .version import get_version + __version__ = get_version(pep440=False) diff --git a/flows/aadc_db.py b/flows/aadc_db.py index 1e0a283..8d25ab9 100644 --- a/flows/aadc_db.py +++ b/flows/aadc_db.py @@ -13,55 +13,56 @@ import psycopg2 as psql from psycopg2.extras import DictCursor import getpass -from .config import load_config - -#-------------------------------------------------------------------------------------------------- -class AADC_DB(object): # pragma: no cover - """ - Connection to the central TASOC database. - - Attributes: - conn (`psycopg2.Connection` object): Connection to PostgreSQL database. - cursor (`psycopg2.Cursor` object): Cursor to use in database. - """ - - def __init__(self, username=None, password=None): - """ - Open connection to central TASOC database. - - If ``username`` or ``password`` is not provided or ``None``, - the user will be prompted for them. - - Parameters: - username (string or None, optional): Username for AADC database. - password (string or None, optional): Password for AADC database. - """ - - config = load_config() - - if username is None: - username = config.get('database', 'username', fallback=None) - if username is None: - default_username = getpass.getuser() - username = input('Username [%s]: ' % default_username) - if username == '': - username = default_username - - if password is None: - password = config.get('database', 'password', fallback=None) - if password is None: - password = getpass.getpass('Password: ') - - # Open database connection: - self.conn = psql.connect('host=10.28.0.127 user=' + username + ' password=' + password + ' dbname=db_aadc') - self.cursor = self.conn.cursor(cursor_factory=DictCursor) - - def close(self): - self.cursor.close() - self.conn.close() - - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - self.close() +from tendrils.utils import load_config + + +# -------------------------------------------------------------------------------------------------- +class AADC_DB(object): # pragma: no cover + """ + Connection to the central TASOC database. + + Attributes: + conn (`psycopg2.Connection` object): Connection to PostgreSQL database. + cursor (`psycopg2.Cursor` object): Cursor to use in database. + """ + + def __init__(self, username=None, password=None): + """ + Open connection to central TASOC database. + + If ``username`` or ``password`` is not provided or ``None``, + the user will be prompted for them. + + Parameters: + username (string or None, optional): Username for AADC database. + password (string or None, optional): Password for AADC database. + """ + + config = load_config() + + if username is None: + username = config.get('database', 'username', fallback=None) + if username is None: + default_username = getpass.getuser() + username = input('Username [%s]: ' % default_username) + if username == '': + username = default_username + + if password is None: + password = config.get('database', 'password', fallback=None) + if password is None: + password = getpass.getpass('Password: ') + + # Open database connection: + self.conn = psql.connect('host=10.28.0.127 user=' + username + ' password=' + password + ' dbname=db_aadc') + self.cursor = self.conn.cursor(cursor_factory=DictCursor) + + def close(self): + self.cursor.close() + self.conn.close() + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.close() diff --git a/flows/api/__init__.py b/flows/api/__init__.py deleted file mode 100644 index 8ad4638..0000000 --- a/flows/api/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# flake8: noqa - -from .targets import get_targets, get_target, add_target -from .datafiles import get_datafile, get_datafiles -from .catalogs import get_catalog, get_catalog_missing -from .sites import get_site, get_all_sites -from .photometry_api import get_photometry, upload_photometry -from .set_photometry_status import set_photometry_status, cleanup_photometry_status -from .filters import get_filters -from .lightcurves import get_lightcurve diff --git a/flows/api/catalogs.py b/flows/api/catalogs.py deleted file mode 100644 index 6657b89..0000000 --- a/flows/api/catalogs.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" - -.. codeauthor:: Rasmus Handberg -""" - -import astropy.units as u -from astropy.table import Table -from astropy.time import Time -import requests -from functools import lru_cache -from ..config import load_config - -#-------------------------------------------------------------------------------------------------- -@lru_cache(maxsize=10) -def get_catalog(target, radius=None, output='table'): - """ - - Parameters: - target (int or str): - radius (float, optional): Radius around target in degrees to return targets for. - outout (str, optional): Desired output format. Choises are 'table', 'dict', 'json'. - Default='table'. - - Returns: - dict: Dictionary with three members: - - 'target': Information about target. - - 'references': Table with information about reference stars close to target. - - 'avoid': Table with stars close to target which should be avoided in FOV selection. - - .. codeauthor:: Rasmus Handberg - """ - - assert output in ('table', 'json', 'dict'), "Invalid output format" - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # - r = requests.get('https://flows.phys.au.dk/api/reference_stars.php', - params={'target': target}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Convert timestamps to actual Time objects: - jsn['target']['inserted'] = Time(jsn['target']['inserted'], scale='utc') - if jsn['target']['discovery_date'] is not None: - jsn['target']['discovery_date'] = Time(jsn['target']['discovery_date'], scale='utc') - - if output in ('json', 'dict'): - return jsn - - dict_tables = {} - - tab = Table( - names=('targetid', 'target_name', 'target_status', 'ra', 'decl', 'redshift', 'redshift_error', 'discovery_mag', 'catalog_downloaded', 'pointing_model_created', 'inserted', 'discovery_date', 'project', 'host_galaxy', 'ztf_id', 'sntype'), - dtype=('int32', 'str', 'str', 'float64', 'float64', 'float32', 'float32', 'float32', 'bool', 'bool', 'object', 'object', 'str', 'str', 'str', 'str'), - rows=[jsn['target']]) - - tab['ra'].description = 'Right ascension' - tab['ra'].unit = u.deg - tab['decl'].description = 'Declination' - tab['decl'].unit = u.deg - dict_tables['target'] = tab - - for table_name in ('references', 'avoid'): - tab = Table( - names=('starid', 'ra', 'decl', 'pm_ra', 'pm_dec', 'gaia_mag', 'gaia_bp_mag', 'gaia_rp_mag', 'gaia_variability', 'B_mag', 'V_mag', 'H_mag', 'J_mag', 'K_mag', 'u_mag', 'g_mag', 'r_mag', 'i_mag', 'z_mag', 'distance'), - dtype=('int64', 'float64', 'float64', 'float32', 'float32', 'float32', 'float32', 'float32', 'int32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float64'), - rows=jsn[table_name]) - - tab['starid'].description = 'Unique identifier in REFCAT2 catalog' - tab['ra'].description = 'Right ascension' - tab['ra'].unit = u.deg - tab['decl'].description = 'Declination' - tab['decl'].unit = u.deg - tab['pm_ra'].description = 'Proper motion in right ascension' - tab['pm_ra'].unit = u.mas/u.yr - tab['pm_dec'].description = 'Proper motion in declination' - tab['pm_dec'].unit = u.mas/u.yr - tab['distance'].description = 'Distance from object to target' - tab['distance'].unit = u.deg - - tab['gaia_mag'].description = 'Gaia G magnitude' - tab['gaia_bp_mag'].description = 'Gaia Bp magnitude' - tab['gaia_rp_mag'].description = 'Gaia Rp magnitude' - tab['gaia_variability'].description = 'Gaia variability classification' - tab['B_mag'].description = 'Johnson B magnitude' - tab['V_mag'].description = 'Johnson V magnitude' - tab['H_mag'].description = '2MASS H magnitude' - tab['J_mag'].description = '2MASS J magnitude' - tab['K_mag'].description = '2MASS K magnitude' - tab['u_mag'].description = 'u magnitude' - tab['g_mag'].description = 'g magnitude' - tab['r_mag'].description = 'r magnitude' - tab['i_mag'].description = 'i magnitude' - tab['z_mag'].description = 'z magnitude' - - # Add some meta-data to the table as well: - tab.meta['targetid'] = int(dict_tables['target']['targetid']) - - dict_tables[table_name] = tab - - return dict_tables - -#-------------------------------------------------------------------------------------------------- -def get_catalog_missing(): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise Exception("No API token has been defined") - - # - r = requests.get('https://flows.phys.au.dk/api/catalog_missing.php', - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - return r.json() diff --git a/flows/api/datafiles.py b/flows/api/datafiles.py deleted file mode 100644 index 5cf7e31..0000000 --- a/flows/api/datafiles.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" - -.. codeauthor:: Rasmus Handberg -""" - -import requests -from datetime import datetime -from functools import lru_cache -from ..config import load_config - -#-------------------------------------------------------------------------------------------------- -@lru_cache(maxsize=10) -def get_datafile(fileid): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - r = requests.get('https://flows.phys.au.dk/api/datafiles.php', - params={'fileid': fileid}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Parse some of the fields to Python objects: - jsn['inserted'] = datetime.strptime(jsn['inserted'], '%Y-%m-%d %H:%M:%S.%f') - jsn['lastmodified'] = datetime.strptime(jsn['lastmodified'], '%Y-%m-%d %H:%M:%S.%f') - - return jsn - -#-------------------------------------------------------------------------------------------------- -def get_datafiles(targetid=None, filt=None, minversion=None): - """ - Get list of data file IDs to be processed. - - Parameters: - targetid (int, optional): Target ID to process. - filt (str, optional): Filter the returned list: - - ``missing``: Return only data files that have not yet been processed. - - ``'all'``: Return all data files. - minversion (str, optional): Special filter matching files not processed at least with - the specified version (defined internally in API for now). - - Returns: - list: List of data files the can be processed. - - .. codeauthor:: Rasmus Handberg - """ - - # Validate input: - if filt is None: - filt = 'missing' - if filt not in ('missing', 'all', 'error'): - raise ValueError("Invalid filter specified: '%s'" % filt) - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - params = {} - if targetid is not None: - params['targetid'] = targetid - if minversion is not None: - params['minversion'] = minversion - params['filter'] = filt - - r = requests.get('https://flows.phys.au.dk/api/datafiles.php', - params=params, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - return jsn diff --git a/flows/api/filters.py b/flows/api/filters.py deleted file mode 100644 index 05e690c..0000000 --- a/flows/api/filters.py +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" - -.. codeauthor:: Rasmus Handberg -""" - -import requests -from functools import lru_cache -import astropy.units as u -from ..config import load_config - -#-------------------------------------------------------------------------------------------------- -@lru_cache(maxsize=10) -def get_filters(): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - r = requests.get('https://flows.phys.au.dk/api/filters.php', - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Add units: - for f, val in jsn.items(): - if val.get('wavelength_center'): - val['wavelength_center'] *= u.nm - if val.get('wavelength_width'): - val['wavelength_width'] *= u.nm - - return jsn diff --git a/flows/api/lightcurves.py b/flows/api/lightcurves.py deleted file mode 100644 index a2eb22a..0000000 --- a/flows/api/lightcurves.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Fetch current lightcurve from Flows API. - -.. codeauthor:: Rasmus Handberg -""" - -import requests -import os.path -import tempfile -from astropy.table import Table -from ..config import load_config - -#-------------------------------------------------------------------------------------------------- -def get_lightcurve(target): - """ - Retrieve lightcurve from Flows server. - - Parameters: - target (int): Target to download lightcurve for. - - Returns: - :class:`astropy.table.Table`: Table containing lightcurve. - - TODO: - - Enable caching of files. - - .. codeauthor:: Rasmus Handberg - """ - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # Send query to the Flows API: - params = {'target': target} - r = requests.get('https://flows.phys.au.dk/api/lightcurve.php', - params=params, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - - # Create tempory directory and save the file into there, - # then open the file as a Table: - with tempfile.TemporaryDirectory() as tmpdir: - tmpfile = os.path.join(tmpdir, 'table.ecsv') - with open(tmpfile, 'w') as fid: - fid.write(r.text) - - tab = Table.read(tmpfile, format='ascii.ecsv') - - return tab diff --git a/flows/api/photometry_api.py b/flows/api/photometry_api.py deleted file mode 100644 index fdddf08..0000000 --- a/flows/api/photometry_api.py +++ /dev/null @@ -1,180 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Upload photometry results to Flows server. - -.. codeauthor:: Rasmus Handberg -""" - -import logging -import os -import zipfile -import requests -import tempfile -import shutil -import glob -from tqdm import tqdm -from astropy.table import Table -from .. import api -from ..config import load_config -from ..utilities import get_filehash - -#-------------------------------------------------------------------------------------------------- -def get_photometry(photid): - """ - Retrieve lightcurve from Flows server. - - Please note that it can significantly speed up repeated calls to this function - to specify a cache directory in the config-file under api -> photometry_cache. - This will download the files only once and store them in this local cache for - use in subsequent calls. - - Parameters: - photid (int): Fileid for the photometry file. - - Returns: - :class:`astropy.table.Table`: Table containing photometry. - - .. codeauthor:: Rasmus Handberg - """ - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # Determine where to store the downloaded file: - photcache = config.get('api', 'photometry_cache', fallback=None) - tmpdir = None - if photcache is not None: - photcache = os.path.abspath(photcache) - if not os.path.isdir(photcache): - raise FileNotFoundError(f"Photometry cache directory does not exist: {photcache}") - else: - tmpdir = tempfile.TemporaryDirectory(prefix='flows-api-get_photometry-') - photcache = tmpdir.name - - # Construct path to the photometry file in the cache: - photfile = os.path.join(photcache, f'photometry-{photid:d}.ecsv') - - if not os.path.isfile(photfile): - # Send query to the Flows API: - params = {'fileid': photid} - r = requests.get('https://flows.phys.au.dk/api/download_photometry.php', - params=params, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - - # Create tempory directory and save the file into there, - # then open the file as a Table: - with open(photfile, 'w') as fid: - fid.write(r.text) - - # Read the photometry file: - tab = Table.read(photfile, format='ascii.ecsv') - - # Explicitly cleanup the tempoary directory if it was created: - if tmpdir: - tmpdir.cleanup() - - return tab - -#-------------------------------------------------------------------------------------------------- -def upload_photometry(fileid, delete_completed=False): - """ - Upload photometry results to Flows server. - - This will make the uploaded photometry the active/newest/best photometry and - be used in plots and shown on the website. - - Parameters: - fileid (int): File ID of photometry to upload to server. - delete_completed (bool, optional): Delete the photometry from the local - working directory if the upload was successful. Default=False. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - tqdm_settings = {'disable': None if logger.isEnabledFor(logging.INFO) else True} - - # Use API to get the datafile information: - datafile = api.get_datafile(fileid) - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - photdir_root = config.get('photometry', 'output', fallback='.') - - # Find the photometry output directory for this fileid: - photdir = os.path.join(photdir_root, datafile['target_name'], f'{fileid:05d}') - if not os.path.isdir(photdir): - # Do a last check, to ensure that we have not just added the wrong number of zeros - # to the directory name: - found_photdir = [] - for d in os.listdir(os.path.join(photdir_root, datafile['target_name'])): - if d.isdigit() and int(d) == fileid and os.path.isdir(d): - found_photdir.append(os.path.join(photdir_root, datafile['target_name'], d)) - # If we only found one, use it, otherwise throw an exception: - if len(found_photdir) == 1: - photdir = found_photdir[0] - elif len(found_photdir) > 1: - raise RuntimeError(f"Several photometry output found for fileid={fileid}. \ - You need to do a cleanup of the photometry output directories.") - else: - raise FileNotFoundError(photdir) - - # Make sure required files are actually there: - photdir = os.path.abspath(photdir) - files_existing = os.listdir(photdir) - if 'photometry.ecsv' not in files_existing: - raise FileNotFoundError(os.path.join(photdir, 'photometry.ecsv')) - if 'photometry.log' not in files_existing: - raise FileNotFoundError(os.path.join(photdir, 'photometry.log')) - - # Create list of files to be uploaded: - files = [ - os.path.join(photdir, 'photometry.ecsv'), - os.path.join(photdir, 'photometry.log') - ] - files += glob.glob(os.path.join(photdir, '*.png')) - - # Create ZIP file: - with tempfile.TemporaryDirectory(prefix='flows-upload-') as tmpdir: - # Create ZIP-file within the temp directory: - fpath_zip = os.path.join(tmpdir, f'{fileid:05d}.zip') - - # Create ZIP file with all the files: - with zipfile.ZipFile(fpath_zip, 'w', allowZip64=True) as z: - for f in tqdm(files, desc=f'Zipping {fileid:d}', **tqdm_settings): - logger.debug('Zipping %s', f) - z.write(f, os.path.basename(f)) - - # Change the name of the uploaded file to contain the file hash: - fhash = get_filehash(fpath_zip) - fname_zip = f'{fileid:05d}-{fhash:s}.zip' - - # Send file to the API: - logger.info("Uploading to server...") - with open(fpath_zip, 'rb') as fid: - r = requests.post('https://flows.phys.au.dk/api/upload_photometry.php', - params={'fileid': fileid}, - files={'file': (fname_zip, fid, 'application/zip')}, - headers={'Authorization': 'Bearer ' + token}) - - # Check the returned data from the API: - if r.text.strip() != 'OK': - logger.error(r.text) - raise RuntimeError("An error occurred while uploading photometry: " + r.text) - r.raise_for_status() - - # If we have made it this far, the upload must have been a success: - if delete_completed: - if set([os.path.basename(f) for f in files]) == set(os.listdir(photdir)): - logger.info("Deleting photometry from workdir: '%s'", photdir) - shutil.rmtree(photdir, ignore_errors=True) - else: - logger.warning("Not deleting photometry from workdir: '%s'", photdir) diff --git a/flows/api/set_photometry_status.py b/flows/api/set_photometry_status.py deleted file mode 100644 index b0fc572..0000000 --- a/flows/api/set_photometry_status.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" - -.. codeauthor:: Rasmus Handberg -""" - -import logging -import requests -from ..config import load_config - -#-------------------------------------------------------------------------------------------------- -def set_photometry_status(fileid, status): - """ - Set photometry status. - - Parameters: - fileid (int): - status (str): Choises are 'running', 'error' or 'done'. - - .. codeauthor:: Rasmus Handberg - """ - # Validate the input: - logger = logging.getLogger(__name__) - if status not in ('running', 'error', 'abort', 'ingest', 'done'): - raise ValueError('Invalid status') - - # Get API token from config file: - config = load_config() - i_am_pipeline = config.getboolean('api', 'pipeline', fallback=False) - if not i_am_pipeline: - logger.debug("Not setting status since user is not pipeline") - return False - - # Get API token from config file: - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # Send HTTP request to FLOWS server: - r = requests.get('https://flows.phys.au.dk/api/set_photometry_status.php', - params={'fileid': fileid, 'status': status}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - res = r.text.strip() - - if res != 'OK': - raise RuntimeError(res) - - return True - -#-------------------------------------------------------------------------------------------------- -def cleanup_photometry_status(): - """ - Perform a cleanup of the photometry status indicator. - - This will change all processes still marked as "running" - to "abort" if they have been running for more than a day. - - .. codeauthor:: Rasmus Handberg - """ - # Validate the input: - logger = logging.getLogger(__name__) - - # Get API token from config file: - config = load_config() - i_am_pipeline = config.getboolean('api', 'pipeline', fallback=False) - if not i_am_pipeline: - logger.debug("Not setting status since user is not pipeline") - return False - - # Get API token from config file: - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # Send HTTP request to FLOWS server: - r = requests.get('https://flows.phys.au.dk/api/cleanup_photometry_status.php', - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - res = r.text.strip() - - if res != 'OK': - raise RuntimeError(res) - - return True diff --git a/flows/api/sites.py b/flows/api/sites.py deleted file mode 100644 index 053310d..0000000 --- a/flows/api/sites.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" - -.. codeauthor:: Rasmus Handberg -""" - -import requests -from functools import lru_cache -import astropy.units as u -from astropy.coordinates import EarthLocation -from ..config import load_config - -#-------------------------------------------------------------------------------------------------- -@lru_cache(maxsize=10) -def get_site(siteid): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - r = requests.get('https://flows.phys.au.dk/api/sites.php', - params={'siteid': siteid}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Special derived objects: - jsn['EarthLocation'] = EarthLocation(lat=jsn['latitude']*u.deg, lon=jsn['longitude']*u.deg, height=jsn['elevation']*u.m) - - return jsn - -#-------------------------------------------------------------------------------------------------- -@lru_cache(maxsize=1) -def get_all_sites(): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - r = requests.get('https://flows.phys.au.dk/api/sites.php', - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Special derived objects: - for site in jsn: - site['EarthLocation'] = EarthLocation(lat=site['latitude']*u.deg, lon=site['longitude']*u.deg, height=site['elevation']*u.m) - - return jsn diff --git a/flows/api/targets.py b/flows/api/targets.py deleted file mode 100644 index 4d04216..0000000 --- a/flows/api/targets.py +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Get information about targets in Flows. - -.. codeauthor:: Rasmus Handberg -""" - -import re -from datetime import datetime -import pytz -from astropy.time import Time -import requests -from functools import lru_cache -from ..config import load_config - -#-------------------------------------------------------------------------------------------------- -@lru_cache(maxsize=10) -def get_target(target): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - r = requests.get('https://flows.phys.au.dk/api/targets.php', - params={'target': target}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Parse some of the fields to Python objects: - jsn['inserted'] = datetime.strptime(jsn['inserted'], '%Y-%m-%d %H:%M:%S.%f') - if jsn['discovery_date']: - jsn['discovery_date'] = Time(jsn['discovery_date'], format='iso', scale='utc') - - return jsn - -#-------------------------------------------------------------------------------------------------- -@lru_cache(maxsize=1) -def get_targets(): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - r = requests.get('https://flows.phys.au.dk/api/targets.php', - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Parse some of the fields to Python objects: - for tgt in jsn: - tgt['inserted'] = datetime.strptime(tgt['inserted'], '%Y-%m-%d %H:%M:%S.%f') - if tgt['discovery_date']: - tgt['discovery_date'] = Time(tgt['discovery_date'], format='iso', scale='utc') - - return jsn - -#-------------------------------------------------------------------------------------------------- -def add_target(name, coord, redshift=None, redshift_error=None, discovery_date=None, - discovery_mag=None, host_galaxy=None, ztf=None, sntype=None, status='candidate', - project='flows'): - """ - Add new candidate or target. - - Coordinates are specified using an Astropy SkyCoord object, which can be - created in the following way: - - coord = SkyCoord(ra=19.1, dec=89.00001, unit='deg', frame='icrs') - - The easiest way is to specify ``discovery_date`` as an Astropy Time object: - - discovery_date = Time('2020-01-02 00:00:00', format='iso', scale='utc') - - Alternatively, you can also specify it as a :class:`datetime.datetime` object, - but some care has to be taken with specifying the correct timezone: - - discovery_date = datetime.strptime('2020-01-02 00:00:00', '%Y-%m-%d %H:%M:%S%f') - discovery_date = pytz.timezone('America/New_York').localize(ddate) - - Lastly, it can be given as a simple date-string of the following form, - but here the data has to be given in UTC: - - discovery_date = '2020-01-02 23:56:02.123' - - Parameters: - name (str): Name of target. Must be of the form "YYYYxyz", where YYYY is the year, - and xyz are letters. - coord (:class:ʼastropy.coordinates.SkyCoordʼ): Sky coordinates of target. - redshift (float, optional): Redshift. - redshift_error (float, optional): Uncertainty on redshift. - discovery_date (:class:`astropy.time.Time`, :class:`datetime.datetime` or str, optional): - discovery_mag (float, optional): Magnitude at time of discovery. - host_galaxy (str, optional): Host galaxy name. - sntype (str, optional): Supernovae type (e.g. Ia, Ib, II). - ztf (str, optional): ZTF identifier. - status (str, optional): - project (str, optional): - - Returns: - int: New target identifier in Flows system. - - .. codeauthor:: Rasmus Handberg - """ - # Check and convert input: - if not re.match(r'^[12]\d{3}([A-Z]|[a-z]{2,4})$', name.strip()): - raise ValueError("Invalid target name.") - - if redshift is None and redshift_error is not None: - raise ValueError("Redshift error specified without redshift value") - - if isinstance(discovery_date, Time): - discovery_date = discovery_date.utc.iso - elif isinstance(discovery_date, datetime): - discovery_date = discovery_date.astimezone(pytz.timezone('UTC')) - discovery_date = discovery_date.strftime('%Y-%m-%d %H:%M:%S%f') - elif isinstance(discovery_date, str): - discovery_date = datetime.strptime(discovery_date, '%Y-%m-%d %H:%M:%S%f') - discovery_date = discovery_date.strftime('%Y-%m-%d %H:%M:%S%f') - - if status not in ('candidate', 'target'): - raise ValueError("Invalid target status.") - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # Gather parameters to be sent to API: - params = { - 'targetid': 0, - 'target_name': name.strip(), - 'ra': coord.icrs.ra.deg, - 'decl': coord.icrs.dec.deg, - 'redshift': redshift, - 'redshift_error': redshift_error, - 'discovery_date': discovery_date, - 'discovery_mag': discovery_mag, - 'host_galaxy': host_galaxy, - 'project': project, - 'ztf_id': ztf, - 'target_status': status, - 'sntype': sntype - } - - # Post the request to the API: - r = requests.post('https://flows.phys.au.dk/api/targets_add.php', - data=params, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Check for errors: - if jsn['errors'] is not None: - raise RuntimeError(f"Adding target '{name}' resulted in an error: {jsn['errors']}") - - return int(jsn['targetid']) diff --git a/flows/catalogs.py b/flows/catalogs.py index 264fc23..a25b7e0 100644 --- a/flows/catalogs.py +++ b/flows/catalogs.py @@ -20,617 +20,587 @@ from astropy.table import Table, MaskedColumn from astroquery.sdss import SDSS from astroquery.simbad import Simbad -from .config import load_config +from tendrils.utils import load_config, query_ztf_id from .aadc_db import AADC_DB -from .ztf import query_ztf_id -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class CasjobsError(RuntimeError): - pass + pass + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- class CasjobsMemoryError(RuntimeError): - pass + pass + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def floatval(value): - return None if value == '' or value == 'NA' or value == '0' else float(value) + return None if value == '' or value == 'NA' or value == '0' else float(value) -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def intval(value): - return None if value == '' else int(value) + return None if value == '' else int(value) + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def configure_casjobs(overwrite=False): - """ - Set up CasJobs if needed. - - Parameters: - overwrite (bool, optional): Overwrite existing configuration. Default (False) is to not - overwrite existing configuration. - - .. codeauthor:: Rasmus Handberg - """ - - __dir__ = os.path.dirname(os.path.realpath(__file__)) - casjobs_config = os.path.join(__dir__, 'casjobs', 'CasJobs.config') - if os.path.isfile(casjobs_config) and not overwrite: - return - - config = load_config() - wsid = config.get('casjobs', 'wsid', fallback=None) - passwd = config.get('casjobs', 'password', fallback=None) - if wsid is None or passwd is None: - raise CasjobsError("CasJobs WSID and PASSWORD not in config.ini") - - try: - with open(casjobs_config, 'w') as fid: - fid.write("wsid={0:s}\n".format(wsid)) - fid.write("password={0:s}\n".format(passwd)) - fid.write("default_target=HLSP_ATLAS_REFCAT2\n") - fid.write("default_queue=1\n") - fid.write("default_days=1\n") - fid.write("verbose=false\n") - fid.write("debug=false\n") - fid.write("jobs_location=http://mastweb.stsci.edu/gcasjobs/services/jobs.asmx\n") - except: # noqa: E722, pragma: no cover - if os.path.isfile(casjobs_config): - os.remove(casjobs_config) - -#-------------------------------------------------------------------------------------------------- -def query_casjobs_refcat2(coo_centre, radius=24*u.arcmin): - """ - Uses the CasJobs program to do a cone-search around the position. - - Will first attempt to do single large cone-search, and if that - fails because of CasJobs memory limits, will sub-divide the cone - into smaller queries. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (Angle, optional): Search radius. Default is 24 arcmin. - - Returns: - list: List of dicts with the REFCAT2 information. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - if isinstance(radius, (float, int)): - radius *= u.deg - - try: - results = _query_casjobs_refcat2(coo_centre, radius=radius) - except CasjobsMemoryError: - logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") - results = _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius=radius) - - # Remove duplicate entries: - _, indx = np.unique([res['starid'] for res in results], return_index=True) - results = [results[k] for k in indx] - - # Trim away anything outside radius: - ra = [res['ra'] for res in results] - decl = [res['decl'] for res in results] - coords = SkyCoord(ra=ra, dec=decl, unit='deg', frame='icrs') - sep = coords.separation(coo_centre) - results = [res for k,res in enumerate(results) if sep[k] <= radius] - - logger.debug("Found %d unique results", len(results)) - return results - -#-------------------------------------------------------------------------------------------------- + """ + Set up CasJobs if needed. + + Parameters: + overwrite (bool, optional): Overwrite existing configuration. Default (False) is to not + overwrite existing configuration. + + .. codeauthor:: Rasmus Handberg + """ + + __dir__ = os.path.dirname(os.path.realpath(__file__)) + casjobs_config = os.path.join(__dir__, 'casjobs', 'CasJobs.config') + if os.path.isfile(casjobs_config) and not overwrite: + return + + config = load_config() + wsid = config.get('casjobs', 'wsid', fallback=None) + passwd = config.get('casjobs', 'password', fallback=None) + if wsid is None or passwd is None: + raise CasjobsError("CasJobs WSID and PASSWORD not in config.ini") + + try: + with open(casjobs_config, 'w') as fid: + fid.write("wsid={0:s}\n".format(wsid)) + fid.write("password={0:s}\n".format(passwd)) + fid.write("default_target=HLSP_ATLAS_REFCAT2\n") + fid.write("default_queue=1\n") + fid.write("default_days=1\n") + fid.write("verbose=false\n") + fid.write("debug=false\n") + fid.write("jobs_location=http://mastweb.stsci.edu/gcasjobs/services/jobs.asmx\n") + except: # noqa: E722, pragma: no cover + if os.path.isfile(casjobs_config): + os.remove(casjobs_config) + + +# -------------------------------------------------------------------------------------------------- +def query_casjobs_refcat2(coo_centre, radius=24 * u.arcmin): + """ + Uses the CasJobs program to do a cone-search around the position. + + Will first attempt to do single large cone-search, and if that + fails because of CasJobs memory limits, will sub-divide the cone + into smaller queries. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (Angle, optional): Search radius. Default is 24 arcmin. + + Returns: + list: List of dicts with the REFCAT2 information. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + if isinstance(radius, (float, int)): + radius *= u.deg + + try: + results = _query_casjobs_refcat2(coo_centre, radius=radius) + except CasjobsMemoryError: + logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") + results = _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius=radius) + + # Remove duplicate entries: + _, indx = np.unique([res['starid'] for res in results], return_index=True) + results = [results[k] for k in indx] + + # Trim away anything outside radius: + ra = [res['ra'] for res in results] + decl = [res['decl'] for res in results] + coords = SkyCoord(ra=ra, dec=decl, unit='deg', frame='icrs') + sep = coords.separation(coo_centre) + results = [res for k, res in enumerate(results) if sep[k] <= radius] + + logger.debug("Found %d unique results", len(results)) + return results + + +# -------------------------------------------------------------------------------------------------- def _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius): - logger = logging.getLogger(__name__) - - # Just put in a stop criterion to avoid infinite recursion: - if radius < 0.04*u.deg: - raise Exception("Too many subdivides") - - # Search central cone: - try: - results = _query_casjobs_refcat2(coo_centre, radius=0.5*radius) - except CasjobsMemoryError: - logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") - results = _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius=0.5*radius) - - # Search six cones around central cone: - for n in range(6): - # FIXME: The 0.8 here is kind of a guess. There should be an analytic solution - new = SkyCoord( - ra=coo_centre.ra.deg + 0.8 * Angle(radius).deg * np.cos(n*60*np.pi/180), - dec=coo_centre.dec.deg + 0.8 * Angle(radius).deg * np.sin(n*60*np.pi/180), - unit='deg', frame='icrs') - - try: - results += _query_casjobs_refcat2(new, radius=0.5*radius) - except CasjobsMemoryError: - logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") - results += _query_casjobs_refcat2_divide_and_conquer(new, radius=0.5*radius) - - return results - -#-------------------------------------------------------------------------------------------------- -def _query_casjobs_refcat2(coo_centre, radius=24*u.arcmin): - """ - Uses the CasJobs program to do a cone-search around the position. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (Angle, optional): Search radius. Default is 24 arcmin. - - Returns: - list: List of dicts with the REFCAT2 information. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - if isinstance(radius, (float, int)): - radius *= u.deg - - sql = "SELECT r.* FROM fGetNearbyObjEq({ra:f}, {dec:f}, {radius:f}) AS n INNER JOIN HLSP_ATLAS_REFCAT2.refcat2 AS r ON n.objid=r.objid ORDER BY n.distance;".format( - ra=coo_centre.ra.deg, - dec=coo_centre.dec.deg, - radius=Angle(radius).deg - ) - logger.debug(sql) - - # Make sure that CasJobs have been configured: - configure_casjobs() - - # The command to run the casjobs script: - # BEWARE: This may change in the future without warning - it has before! - cmd = 'java -jar casjobs.jar execute "{0:s}"'.format(sql) - - # Execute the command: - cmd = shlex.split(cmd) - directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'casjobs') - proc = subprocess.Popen(cmd, cwd=directory, stdout=subprocess.PIPE, universal_newlines=True) - stdout, stderr = proc.communicate() - output = stdout.split("\n") - - # build list of all kois from output from the CasJobs-script: - error_thrown = False - results = [] - for line in output: - line = line.strip() - if line == '': - continue - if 'ERROR' in line: - error_thrown = True - break - - row = line.split(',') - if len(row) == 45 and row[0] != '[objid]:Integer': - results.append({ - 'starid': int(row[0]), - 'ra': floatval(row[1]), - 'decl': floatval(row[2]), - 'pm_ra': floatval(row[5]), - 'pm_dec': floatval(row[7]), - 'gaia_mag': floatval(row[9]), - 'gaia_bp_mag': floatval(row[11]), - 'gaia_rp_mag': floatval(row[13]), - 'gaia_variability': intval(row[17]), - 'g_mag': floatval(row[22]), - 'r_mag': floatval(row[26]), - 'i_mag': floatval(row[30]), - 'z_mag': floatval(row[34]), - 'J_mag': floatval(row[39]), - 'H_mag': floatval(row[41]), - 'K_mag': floatval(row[43]), - }) - - if error_thrown: - error_msg = '' - for line in output: - if len(line.strip()) > 0: - error_msg += line.strip() + "\n" - - logger.debug("Error Msg: %s", error_msg) - if 'query results exceed memory limit' in error_msg.lower(): - raise CasjobsMemoryError("Query results exceed memory limit") - else: - raise CasjobsError("ERROR detected in CasJobs: " + error_msg) - - if not results: - raise CasjobsError("Could not find anything on CasJobs") - - logger.debug("Found %d results", len(results)) - return results - -#-------------------------------------------------------------------------------------------------- -def query_apass(coo_centre, radius=24*u.arcmin): - """ - Queries APASS catalog using cone-search around the position. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float, optional): - - Returns: - list: List of dicts with the APASS information. - - .. codeauthor:: Rasmus Handberg - """ - - # https://vizier.u-strasbg.fr/viz-bin/VizieR-3?-source=II/336 - - if isinstance(radius, (float, int)): - radius *= u.deg - - data = { - 'ra': coo_centre.icrs.ra.deg, - 'dec': coo_centre.icrs.dec.deg, - 'radius': Angle(radius).deg, - 'outtype': '1' - } - - res = requests.post('https://www.aavso.org/cgi-bin/apass_dr10_download.pl', data=data) - res.raise_for_status() - - results = [] - - lines = res.text.split("\n") - #header = lines[0] - - for line in lines[1:]: - if line.strip() == '': continue - row = line.strip().split(',') - - results.append({ - 'ra': floatval(row[0]), - 'decl': floatval(row[2]), - 'V_mag': floatval(row[4]), - 'B_mag': floatval(row[7]), - 'u_mag': floatval(row[10]), - 'g_mag': floatval(row[13]), - 'r_mag': floatval(row[16]), - 'i_mag': floatval(row[19]), - 'z_mag': floatval(row[22]), - 'Y_mag': floatval(row[25]) - }) - - return results - -#-------------------------------------------------------------------------------------------------- -def query_sdss(coo_centre, radius=24*u.arcmin, dr=16, clean=True): - """ - Queries SDSS catalog using cone-search around the position. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float, optional): - dr (int, optional): SDSS Data Release to query. Default=16. - clean (bool, optional): Clean results for stars only and no other problems. - - Returns: - tuple: - - :class:`astropy.table.Table`: Table with SDSS information. - - :class:`astropy.coordinates.SkyCoord`: Sky coordinates for SDSS objects. - - .. codeauthor:: Emir Karamehmetoglu - .. codeauthor:: Rasmus Handberg - """ - - if isinstance(radius, (float, int)): - radius *= u.deg - - AT_sdss = SDSS.query_region(coo_centre, - photoobj_fields=['type', 'clean', 'ra', 'dec', 'psfMag_u'], - data_release=dr, - timeout=600, - radius=radius) - - if AT_sdss is None: - return None, None - - if clean: - # Clean SDSS following https://www.sdss.org/dr12/algorithms/photo_flags_recommend/ - # 6 == star, clean means remove interp, edge, suspicious defects, deblending problems, duplicates. - AT_sdss = AT_sdss[(AT_sdss['type'] == 6) & (AT_sdss['clean'] == 1)] - - # Remove these columns since they are no longer needed: - AT_sdss.remove_columns(['type', 'clean']) - - if len(AT_sdss) == 0: - return None, None - - # Create SkyCoord object with the coordinates: - sdss = SkyCoord( - ra=AT_sdss['ra'], - dec=AT_sdss['dec'], - unit=u.deg, - frame='icrs') - - return AT_sdss, sdss - -#-------------------------------------------------------------------------------------------------- -def query_simbad(coo_centre, radius=24*u.arcmin): - """ - Query SIMBAD using cone-search around the position using astroquery. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float, optional): - - Returns: - list: Astropy Table with SIMBAD information. - - .. codeauthor:: Rasmus Handberg - """ - - s = Simbad() - s.ROW_LIMIT = 0 - s.remove_votable_fields('coordinates') - s.add_votable_fields('ra(d;A;ICRS;J2000)', 'dec(d;D;ICRS;2000)', 'pmra', 'pmdec') - s.add_votable_fields('otype') - s.add_votable_fields('flux(B)', 'flux(V)', 'flux(R)', 'flux(I)', 'flux(J)', 'flux(H)', 'flux(K)') - s.add_votable_fields('flux(u)', 'flux(g)', 'flux(r)', 'flux(i)', 'flux(z)') - - rad = Angle(radius).arcmin - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning) - results = s.query_criteria(f'region(circle, icrs, {coo_centre.icrs.ra.deg:.10f} {coo_centre.icrs.dec.deg:+.10f}, {rad}m)', otypes='Star') - - if not results: - return None, None - - # Rename columns: - results.rename_column('MAIN_ID', 'main_id') - results.rename_column('RA_d_A_ICRS_J2000', 'ra') - results.rename_column('DEC_d_D_ICRS_2000', 'dec') - results.rename_column('PMRA', 'pmra') - results.rename_column('PMDEC', 'pmdec') - results.rename_column('FLUX_B', 'B_mag') - results.rename_column('FLUX_V', 'V_mag') - results.rename_column('FLUX_R', 'R_mag') - results.rename_column('FLUX_I', 'I_mag') - results.rename_column('FLUX_J', 'J_mag') - results.rename_column('FLUX_H', 'H_mag') - results.rename_column('FLUX_K', 'K_mag') - results.rename_column('FLUX_u', 'u_mag') - results.rename_column('FLUX_g', 'g_mag') - results.rename_column('FLUX_r', 'r_mag') - results.rename_column('FLUX_i', 'i_mag') - results.rename_column('FLUX_z', 'z_mag') - results.rename_column('OTYPE', 'otype') - results.remove_column('SCRIPT_NUMBER_ID') - results.sort(['V_mag', 'B_mag', 'H_mag']) - - # Filter out object types which shouldn'r really be in there anyway: - indx = (results['otype'] == 'Galaxy') | (results['otype'] == 'LINER') | (results['otype'] == 'SN') - results = results[~indx] - - if len(results) == 0: - return None, None - - # Build sky coordinates object: - simbad = SkyCoord( - ra=results['ra'], - dec=results['dec'], - pm_ra_cosdec=results['pmra'], - pm_dec=results['pmdec'], - frame='icrs', - obstime='J2000') - - return results, simbad - -#-------------------------------------------------------------------------------------------------- -def query_skymapper(coo_centre, radius=24*u.arcmin): - """ - Queries SkyMapper catalog using cone-search around the position. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float, optional): - - Returns: - tuple: - - :class:`astropy.table.Table`: Astropy Table with SkyMapper information. - - :class:`astropy.coordinates.SkyCoord`: - - .. codeauthor:: Rasmus Handberg - """ - - if isinstance(radius, (float, int)): - radius *= u.deg - - # Query the SkyMapper cone-search API: - params = { - 'RA': coo_centre.icrs.ra.deg, - 'DEC': coo_centre.icrs.dec.deg, - 'SR': Angle(radius).deg, - 'CATALOG': 'dr2.master', - 'VERB': 1, - 'RESPONSEFORMAT': 'VOTABLE' - } - res = requests.get('http://skymapper.anu.edu.au/sm-cone/public/query', params=params) - res.raise_for_status() - - # For some reason the VOTable parser needs a file-like object: - fid = BytesIO(bytes(res.text, 'utf8')) - results = Table.read(fid, format='votable') - - if len(results) == 0: - return None, None - - # Clean the results: - # http://skymapper.anu.edu.au/data-release/dr2/#Access - indx = (results['flags'] == 0) & (results['nimaflags'] == 0) & (results['ngood'] > 1) - results = results[indx] - if len(results) == 0: - return None, None - - # Create SkyCoord object containing SkyMapper objects with their observation time: - skymapper = SkyCoord( - ra=results['raj2000'], - dec=results['dej2000'], - obstime=Time(results['mean_epoch'], format='mjd', scale='utc'), - frame='icrs') - - return results, skymapper - -#-------------------------------------------------------------------------------------------------- -def query_all(coo_centre, radius=24*u.arcmin, dist_cutoff=2*u.arcsec): - """ - Query all catalogs (REFCAT2, APASS, SDSS and SkyMapper) and return merged catalog. - - Merging of catalogs are done using sky coordinates: - https://docs.astropy.org/en/stable/coordinates/matchsep.html#matching-catalogs - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float): Search radius. Default 24 arcmin. - dist_cutoff (float): Maximal distance between object is catalog matching. Default 2 arcsec. - - Returns: - :class:`astropy.table.Table`: Table with catalog stars. - - TODO: - - Use the overlapping magnitudes to make better matching. - - .. codeauthor:: Rasmus Handberg - .. codeauthor:: Emir Karamehmetoglu - """ - - # Query the REFCAT2 catalog using CasJobs around the target position: - results = query_casjobs_refcat2(coo_centre, radius=radius) - AT_results = Table(results) - refcat = SkyCoord(ra=AT_results['ra'], dec=AT_results['decl'], unit=u.deg, frame='icrs') - - # REFCAT results table does not have uBV - N = len(AT_results) - d = np.full(N, np.NaN) - AT_results.add_column(MaskedColumn(name='B_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) - AT_results.add_column(MaskedColumn(name='V_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) - AT_results.add_column(MaskedColumn(name='u_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) - - # Query APASS around the target position: - results_apass = query_apass(coo_centre, radius=radius) - if results_apass: - AT_apass = Table(results_apass) - apass = SkyCoord(ra=AT_apass['ra'], dec=AT_apass['decl'], unit=u.deg, frame='icrs') - - # Match the two catalogs using coordinates: - idx, d2d, _ = apass.match_to_catalog_sky(refcat) - sep_constraint = (d2d <= dist_cutoff) # Reject any match further away than the cutoff - idx_apass = np.arange(len(idx), dtype='int') # since idx maps apass to refcat - - # Update results table with APASS bands of interest - AT_results['B_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['B_mag'] - AT_results['V_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['V_mag'] - AT_results['u_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['u_mag'] - - # Create SDSS around the target position: - AT_sdss, sdss = query_sdss(coo_centre, radius=radius) - if AT_sdss: - # Match to dist_cutoff sky distance (angular) apart - idx, d2d, _ = sdss.match_to_catalog_sky(refcat) - sep_constraint = (d2d <= dist_cutoff) - idx_sdss = np.arange(len(idx), dtype='int') # since idx maps sdss to refcat - - # Overwrite APASS u-band with SDSS u-band: - AT_results['u_mag'][idx[sep_constraint]] = AT_sdss[idx_sdss[sep_constraint]]['psfMag_u'] - - # Query SkyMapper around the target position, only if there are missing u-band magnitudes: - if anynan(AT_results['u_mag']): - results_skymapper, skymapper = query_skymapper(coo_centre, radius=radius) - if results_skymapper: - idx, d2d, _ = skymapper.match_to_catalog_sky(refcat) - sep_constraint = (d2d <= dist_cutoff) - idx_skymapper = np.arange(len(idx), dtype='int') # since idx maps skymapper to refcat - - newval = results_skymapper[idx_skymapper[sep_constraint]]['u_psf'] - oldval = AT_results['u_mag'][idx[sep_constraint]] - indx = ~np.isfinite(oldval) - if np.any(indx): - AT_results['u_mag'][idx[sep_constraint]][indx] = newval[indx] - - return AT_results - -#-------------------------------------------------------------------------------------------------- + logger = logging.getLogger(__name__) + + # Just put in a stop criterion to avoid infinite recursion: + if radius < 0.04 * u.deg: + raise Exception("Too many subdivides") + + # Search central cone: + try: + results = _query_casjobs_refcat2(coo_centre, radius=0.5 * radius) + except CasjobsMemoryError: + logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") + results = _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius=0.5 * radius) + + # Search six cones around central cone: + for n in range(6): + # FIXME: The 0.8 here is kind of a guess. There should be an analytic solution + new = SkyCoord(ra=coo_centre.ra.deg + 0.8 * Angle(radius).deg * np.cos(n * 60 * np.pi / 180), + dec=coo_centre.dec.deg + 0.8 * Angle(radius).deg * np.sin(n * 60 * np.pi / 180), unit='deg', + frame='icrs') + + try: + results += _query_casjobs_refcat2(new, radius=0.5 * radius) + except CasjobsMemoryError: + logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") + results += _query_casjobs_refcat2_divide_and_conquer(new, radius=0.5 * radius) + + return results + + +# -------------------------------------------------------------------------------------------------- +def _query_casjobs_refcat2(coo_centre, radius=24 * u.arcmin): + """ + Uses the CasJobs program to do a cone-search around the position. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (Angle, optional): Search radius. Default is 24 arcmin. + + Returns: + list: List of dicts with the REFCAT2 information. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + if isinstance(radius, (float, int)): + radius *= u.deg + + sql = "SELECT r.* FROM fGetNearbyObjEq({ra:f}, {dec:f}, {radius:f}) AS n INNER JOIN HLSP_ATLAS_REFCAT2.refcat2 AS r ON n.objid=r.objid ORDER BY n.distance;".format( + ra=coo_centre.ra.deg, dec=coo_centre.dec.deg, radius=Angle(radius).deg) + logger.debug(sql) + + # Make sure that CasJobs have been configured: + configure_casjobs() + + # The command to run the casjobs script: + # BEWARE: This may change in the future without warning - it has before! + cmd = 'java -jar casjobs.jar execute "{0:s}"'.format(sql) + + # Execute the command: + cmd = shlex.split(cmd) + directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'casjobs') + proc = subprocess.Popen(cmd, cwd=directory, stdout=subprocess.PIPE, universal_newlines=True) + stdout, stderr = proc.communicate() + output = stdout.split("\n") + + # build list of all kois from output from the CasJobs-script: + error_thrown = False + results = [] + for line in output: + line = line.strip() + if line == '': + continue + if 'ERROR' in line: + error_thrown = True + break + + row = line.split(',') + if len(row) == 45 and row[0] != '[objid]:Integer': + results.append( + {'starid': int(row[0]), 'ra': floatval(row[1]), 'decl': floatval(row[2]), 'pm_ra': floatval(row[5]), + 'pm_dec': floatval(row[7]), 'gaia_mag': floatval(row[9]), 'gaia_bp_mag': floatval(row[11]), + 'gaia_rp_mag': floatval(row[13]), 'gaia_variability': intval(row[17]), 'g_mag': floatval(row[22]), + 'r_mag': floatval(row[26]), 'i_mag': floatval(row[30]), 'z_mag': floatval(row[34]), + 'J_mag': floatval(row[39]), 'H_mag': floatval(row[41]), 'K_mag': floatval(row[43]), }) + + if error_thrown: + error_msg = '' + for line in output: + if len(line.strip()) > 0: + error_msg += line.strip() + "\n" + + logger.debug("Error Msg: %s", error_msg) + if 'query results exceed memory limit' in error_msg.lower(): + raise CasjobsMemoryError("Query results exceed memory limit") + else: + raise CasjobsError("ERROR detected in CasJobs: " + error_msg) + + if not results: + raise CasjobsError("Could not find anything on CasJobs") + + logger.debug("Found %d results", len(results)) + return results + + +# -------------------------------------------------------------------------------------------------- +def query_apass(coo_centre, radius=24 * u.arcmin): + """ + Queries APASS catalog using cone-search around the position. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (float, optional): + + Returns: + list: List of dicts with the APASS information. + + .. codeauthor:: Rasmus Handberg + """ + + # https://vizier.u-strasbg.fr/viz-bin/VizieR-3?-source=II/336 + + if isinstance(radius, (float, int)): + radius *= u.deg + + data = {'ra': coo_centre.icrs.ra.deg, 'dec': coo_centre.icrs.dec.deg, 'radius': Angle(radius).deg, 'outtype': '1'} + + res = requests.post('https://www.aavso.org/cgi-bin/apass_dr10_download.pl', data=data) + res.raise_for_status() + + results = [] + + lines = res.text.split("\n") + # header = lines[0] + + for line in lines[1:]: + if line.strip() == '': continue + row = line.strip().split(',') + + results.append( + {'ra': floatval(row[0]), 'decl': floatval(row[2]), 'V_mag': floatval(row[4]), 'B_mag': floatval(row[7]), + 'u_mag': floatval(row[10]), 'g_mag': floatval(row[13]), 'r_mag': floatval(row[16]), + 'i_mag': floatval(row[19]), 'z_mag': floatval(row[22]), 'Y_mag': floatval(row[25])}) + + return results + + +# -------------------------------------------------------------------------------------------------- +def query_sdss(coo_centre, radius=24 * u.arcmin, dr=16, clean=True): + """ + Queries SDSS catalog using cone-search around the position. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (float, optional): + dr (int, optional): SDSS Data Release to query. Default=16. + clean (bool, optional): Clean results for stars only and no other problems. + + Returns: + tuple: + - :class:`astropy.table.Table`: Table with SDSS information. + - :class:`astropy.coordinates.SkyCoord`: Sky coordinates for SDSS objects. + + .. codeauthor:: Emir Karamehmetoglu + .. codeauthor:: Rasmus Handberg + """ + + if isinstance(radius, (float, int)): + radius *= u.deg + + AT_sdss = SDSS.query_region(coo_centre, photoobj_fields=['type', 'clean', 'ra', 'dec', 'psfMag_u'], data_release=dr, + timeout=600, radius=radius) + + if AT_sdss is None: + return None, None + + if clean: + # Clean SDSS following https://www.sdss.org/dr12/algorithms/photo_flags_recommend/ + # 6 == star, clean means remove interp, edge, suspicious defects, deblending problems, duplicates. + AT_sdss = AT_sdss[(AT_sdss['type'] == 6) & (AT_sdss['clean'] == 1)] + + # Remove these columns since they are no longer needed: + AT_sdss.remove_columns(['type', 'clean']) + + if len(AT_sdss) == 0: + return None, None + + # Create SkyCoord object with the coordinates: + sdss = SkyCoord(ra=AT_sdss['ra'], dec=AT_sdss['dec'], unit=u.deg, frame='icrs') + + return AT_sdss, sdss + + +# -------------------------------------------------------------------------------------------------- +def query_simbad(coo_centre, radius=24 * u.arcmin): + """ + Query SIMBAD using cone-search around the position using astroquery. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (float, optional): + + Returns: + list: Astropy Table with SIMBAD information. + + .. codeauthor:: Rasmus Handberg + """ + + s = Simbad() + s.ROW_LIMIT = 0 + s.remove_votable_fields('coordinates') + s.add_votable_fields('ra(d;A;ICRS;J2000)', 'dec(d;D;ICRS;2000)', 'pmra', 'pmdec') + s.add_votable_fields('otype') + s.add_votable_fields('flux(B)', 'flux(V)', 'flux(R)', 'flux(I)', 'flux(J)', 'flux(H)', 'flux(K)') + s.add_votable_fields('flux(u)', 'flux(g)', 'flux(r)', 'flux(i)', 'flux(z)') + + rad = Angle(radius).arcmin + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning) + results = s.query_criteria( + f'region(circle, icrs, {coo_centre.icrs.ra.deg:.10f} {coo_centre.icrs.dec.deg:+.10f}, {rad}m)', + otypes='Star') + + if not results: + return None, None + + # Rename columns: + results.rename_column('MAIN_ID', 'main_id') + results.rename_column('RA_d_A_ICRS_J2000', 'ra') + results.rename_column('DEC_d_D_ICRS_2000', 'dec') + results.rename_column('PMRA', 'pmra') + results.rename_column('PMDEC', 'pmdec') + results.rename_column('FLUX_B', 'B_mag') + results.rename_column('FLUX_V', 'V_mag') + results.rename_column('FLUX_R', 'R_mag') + results.rename_column('FLUX_I', 'I_mag') + results.rename_column('FLUX_J', 'J_mag') + results.rename_column('FLUX_H', 'H_mag') + results.rename_column('FLUX_K', 'K_mag') + results.rename_column('FLUX_u', 'u_mag') + results.rename_column('FLUX_g', 'g_mag') + results.rename_column('FLUX_r', 'r_mag') + results.rename_column('FLUX_i', 'i_mag') + results.rename_column('FLUX_z', 'z_mag') + results.rename_column('OTYPE', 'otype') + results.remove_column('SCRIPT_NUMBER_ID') + results.sort(['V_mag', 'B_mag', 'H_mag']) + + # Filter out object types which shouldn'r really be in there anyway: + indx = (results['otype'] == 'Galaxy') | (results['otype'] == 'LINER') | (results['otype'] == 'SN') + results = results[~indx] + + if len(results) == 0: + return None, None + + # Build sky coordinates object: + simbad = SkyCoord(ra=results['ra'], dec=results['dec'], pm_ra_cosdec=results['pmra'], pm_dec=results['pmdec'], + frame='icrs', obstime='J2000') + + return results, simbad + + +# -------------------------------------------------------------------------------------------------- +def query_skymapper(coo_centre, radius=24 * u.arcmin): + """ + Queries SkyMapper catalog using cone-search around the position. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (float, optional): + + Returns: + tuple: + - :class:`astropy.table.Table`: Astropy Table with SkyMapper information. + - :class:`astropy.coordinates.SkyCoord`: + + .. codeauthor:: Rasmus Handberg + """ + + if isinstance(radius, (float, int)): + radius *= u.deg + + # Query the SkyMapper cone-search API: + params = {'RA': coo_centre.icrs.ra.deg, 'DEC': coo_centre.icrs.dec.deg, 'SR': Angle(radius).deg, + 'CATALOG': 'dr2.master', 'VERB': 1, 'RESPONSEFORMAT': 'VOTABLE'} + res = requests.get('http://skymapper.anu.edu.au/sm-cone/public/query', params=params) + res.raise_for_status() + + # For some reason the VOTable parser needs a file-like object: + fid = BytesIO(bytes(res.text, 'utf8')) + results = Table.read(fid, format='votable') + + if len(results) == 0: + return None, None + + # Clean the results: + # http://skymapper.anu.edu.au/data-release/dr2/#Access + indx = (results['flags'] == 0) & (results['nimaflags'] == 0) & (results['ngood'] > 1) + results = results[indx] + if len(results) == 0: + return None, None + + # Create SkyCoord object containing SkyMapper objects with their observation time: + skymapper = SkyCoord(ra=results['raj2000'], dec=results['dej2000'], + obstime=Time(results['mean_epoch'], format='mjd', scale='utc'), frame='icrs') + + return results, skymapper + + +# -------------------------------------------------------------------------------------------------- +def query_all(coo_centre, radius=24 * u.arcmin, dist_cutoff=2 * u.arcsec): + """ + Query all catalogs (REFCAT2, APASS, SDSS and SkyMapper) and return merged catalog. + + Merging of catalogs are done using sky coordinates: + https://docs.astropy.org/en/stable/coordinates/matchsep.html#matching-catalogs + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (float): Search radius. Default 24 arcmin. + dist_cutoff (float): Maximal distance between object is catalog matching. Default 2 arcsec. + + Returns: + :class:`astropy.table.Table`: Table with catalog stars. + + TODO: + - Use the overlapping magnitudes to make better matching. + + .. codeauthor:: Rasmus Handberg + .. codeauthor:: Emir Karamehmetoglu + """ + + # Query the REFCAT2 catalog using CasJobs around the target position: + results = query_casjobs_refcat2(coo_centre, radius=radius) + AT_results = Table(results) + refcat = SkyCoord(ra=AT_results['ra'], dec=AT_results['decl'], unit=u.deg, frame='icrs') + + # REFCAT results table does not have uBV + N = len(AT_results) + d = np.full(N, np.NaN) + AT_results.add_column(MaskedColumn(name='B_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) + AT_results.add_column(MaskedColumn(name='V_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) + AT_results.add_column(MaskedColumn(name='u_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) + + # Query APASS around the target position: + results_apass = query_apass(coo_centre, radius=radius) + if results_apass: + AT_apass = Table(results_apass) + apass = SkyCoord(ra=AT_apass['ra'], dec=AT_apass['decl'], unit=u.deg, frame='icrs') + + # Match the two catalogs using coordinates: + idx, d2d, _ = apass.match_to_catalog_sky(refcat) + sep_constraint = (d2d <= dist_cutoff) # Reject any match further away than the cutoff + idx_apass = np.arange(len(idx), dtype='int') # since idx maps apass to refcat + + # Update results table with APASS bands of interest + AT_results['B_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['B_mag'] + AT_results['V_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['V_mag'] + AT_results['u_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['u_mag'] + + # Create SDSS around the target position: + AT_sdss, sdss = query_sdss(coo_centre, radius=radius) + if AT_sdss: + # Match to dist_cutoff sky distance (angular) apart + idx, d2d, _ = sdss.match_to_catalog_sky(refcat) + sep_constraint = (d2d <= dist_cutoff) + idx_sdss = np.arange(len(idx), dtype='int') # since idx maps sdss to refcat + + # Overwrite APASS u-band with SDSS u-band: + AT_results['u_mag'][idx[sep_constraint]] = AT_sdss[idx_sdss[sep_constraint]]['psfMag_u'] + + # Query SkyMapper around the target position, only if there are missing u-band magnitudes: + if anynan(AT_results['u_mag']): + results_skymapper, skymapper = query_skymapper(coo_centre, radius=radius) + if results_skymapper: + idx, d2d, _ = skymapper.match_to_catalog_sky(refcat) + sep_constraint = (d2d <= dist_cutoff) + idx_skymapper = np.arange(len(idx), dtype='int') # since idx maps skymapper to refcat + + newval = results_skymapper[idx_skymapper[sep_constraint]]['u_psf'] + oldval = AT_results['u_mag'][idx[sep_constraint]] + indx = ~np.isfinite(oldval) + if np.any(indx): + AT_results['u_mag'][idx[sep_constraint]][indx] = newval[indx] + + return AT_results + + +# -------------------------------------------------------------------------------------------------- def convert_table_to_dict(tab): - """ - Utility function for converting Astropy Table to list of dicts that the database - likes as input. - - Parameters: - tab (:class:`astropy.table.Table`): Astropy table coming from query_all. - - Returns: - list: List of dicts where the column names are the keys. Datatypes are changed - to things that the database understands (e.g. NaN -> None). - - .. codeauthor:: Rasmus Handberg - """ - results = [dict(zip(tab.colnames, row)) for row in tab.filled()] - for row in results: - for key, val in row.items(): - if isinstance(val, (np.int64, np.int32)): - row[key] = int(val) - elif isinstance(val, (float, np.float32, np.float64)): - if np.isfinite(val): - row[key] = float(val) - else: - row[key] = None - - return results - -#-------------------------------------------------------------------------------------------------- -def download_catalog(target=None, radius=24*u.arcmin, radius_ztf=3*u.arcsec, - dist_cutoff=2*u.arcsec, update_existing=False): - """ - Download reference star catalogs and save to Flows database. - - Parameters: - target (str or int): Target identifier to download catalog for. - radius (Angle, optional): Radius around target to download catalogs. - radius_ztf (Angle, optional): Radius around target to search for ZTF identifier. - dist_cutoff (Angle, optional): Distance cutoff used for matching catalog positions. - update_existing (bool, optional): Update existing catalog entries or skip them. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - with AADC_DB() as db: - - # Get the information about the target from the database: - if target is not None and isinstance(target, (int, float)): - db.cursor.execute("SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE targetid=%s;", [int(target)]) - elif target is not None: - db.cursor.execute("SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE target_name=%s;", [target]) - else: - db.cursor.execute("SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE catalog_downloaded=FALSE;") - - for row in db.cursor.fetchall(): - # The unique identifier of the target: - targetid = int(row['targetid']) - target_name = row['target_name'] - dd = row['discovery_date'] - if dd is not None: - dd = Time(dd, format='datetime', scale='utc') - - # Coordinate of the target, which is the centre of the search cone: - coo_centre = SkyCoord(ra=row['ra'], dec=row['decl'], unit=u.deg, frame='icrs') - - # Download combined catalog from all sources: - tab = query_all(coo_centre, radius=radius, dist_cutoff=dist_cutoff) - - # Query for a ZTF identifier for this target: - ztf_id = query_ztf_id(coo_centre, radius=radius_ztf, discovery_date=dd) - - # Because the database is picky with datatypes, we need to change things - # before they are passed on to the database: - results = convert_table_to_dict(tab) - - # Insert the catalog into the local database: - if update_existing: - on_conflict = """ON CONSTRAINT refcat2_pkey DO UPDATE SET + """ + Utility function for converting Astropy Table to list of dicts that the database + likes as input. + + Parameters: + tab (:class:`astropy.table.Table`): Astropy table coming from query_all. + + Returns: + list: List of dicts where the column names are the keys. Datatypes are changed + to things that the database understands (e.g. NaN -> None). + + .. codeauthor:: Rasmus Handberg + """ + results = [dict(zip(tab.colnames, row)) for row in tab.filled()] + for row in results: + for key, val in row.items(): + if isinstance(val, (np.int64, np.int32)): + row[key] = int(val) + elif isinstance(val, (float, np.float32, np.float64)): + if np.isfinite(val): + row[key] = float(val) + else: + row[key] = None + + return results + + +# -------------------------------------------------------------------------------------------------- +def download_catalog(target=None, radius=24 * u.arcmin, radius_ztf=3 * u.arcsec, dist_cutoff=2 * u.arcsec, + update_existing=False): + """ + Download reference star catalogs and save to Flows database. + + Parameters: + target (str or int): Target identifier to download catalog for. + radius (Angle, optional): Radius around target to download catalogs. + radius_ztf (Angle, optional): Radius around target to search for ZTF identifier. + dist_cutoff (Angle, optional): Distance cutoff used for matching catalog positions. + update_existing (bool, optional): Update existing catalog entries or skip them. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + + with AADC_DB() as db: + + # Get the information about the target from the database: + if target is not None and isinstance(target, (int, float)): + db.cursor.execute( + "SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE targetid=%s;", + [int(target)]) + elif target is not None: + db.cursor.execute( + "SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE target_name=%s;", [target]) + else: + db.cursor.execute( + "SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE catalog_downloaded=FALSE;") + + for row in db.cursor.fetchall(): + # The unique identifier of the target: + targetid = int(row['targetid']) + target_name = row['target_name'] + dd = row['discovery_date'] + if dd is not None: + dd = Time(dd, format='datetime', scale='utc') + + # Coordinate of the target, which is the centre of the search cone: + coo_centre = SkyCoord(ra=row['ra'], dec=row['decl'], unit=u.deg, frame='icrs') + + # Download combined catalog from all sources: + tab = query_all(coo_centre, radius=radius, dist_cutoff=dist_cutoff) + + # Query for a ZTF identifier for this target: + ztf_id = query_ztf_id(coo_centre, radius=radius_ztf, discovery_date=dd) + + # Because the database is picky with datatypes, we need to change things + # before they are passed on to the database: + results = convert_table_to_dict(tab) + + # Insert the catalog into the local database: + if update_existing: + on_conflict = """ON CONSTRAINT refcat2_pkey DO UPDATE SET ra=EXCLUDED.ra, decl=EXCLUDED.decl, pm_ra=EXCLUDED.pm_ra, @@ -650,11 +620,11 @@ def download_catalog(target=None, radius=24*u.arcmin, radius_ztf=3*u.arcsec, "V_mag"=EXCLUDED."V_mag", "B_mag"=EXCLUDED."B_mag" WHERE refcat2.starid=EXCLUDED.starid""" - else: - on_conflict = 'DO NOTHING' + else: + on_conflict = 'DO NOTHING' - try: - db.cursor.executemany("""INSERT INTO flows.refcat2 ( + try: + db.cursor.executemany("""INSERT INTO flows.refcat2 ( starid, ra, decl, @@ -695,11 +665,12 @@ def download_catalog(target=None, radius=24*u.arcmin, radius_ztf=3*u.arcsec, %(V_mag)s, %(B_mag)s) ON CONFLICT """ + on_conflict + ";", results) - logger.info("%d catalog entries inserted for %s.", db.cursor.rowcount, target_name) - - # Mark the target that the catalog has been downloaded: - db.cursor.execute("UPDATE flows.targets SET catalog_downloaded=TRUE,ztf_id=%s WHERE targetid=%s;", (ztf_id, targetid)) - db.conn.commit() - except: # noqa: E722, pragma: no cover - db.conn.rollback() - raise + logger.info("%d catalog entries inserted for %s.", db.cursor.rowcount, target_name) + + # Mark the target that the catalog has been downloaded: + db.cursor.execute("UPDATE flows.targets SET catalog_downloaded=TRUE,ztf_id=%s WHERE targetid=%s;", + (ztf_id, targetid)) + db.conn.commit() + except: # noqa: E722, pragma: no cover + db.conn.rollback() + raise diff --git a/flows/config.py b/flows/config.py deleted file mode 100644 index eb9d8af..0000000 --- a/flows/config.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" - -.. codeauthor:: Rasmus Handberg -""" - -import os.path -import configparser -from functools import lru_cache - -#-------------------------------------------------------------------------------------------------- -@lru_cache(maxsize=1) -def load_config(): - """ - Load configuration file. - - Returns: - ``configparser.ConfigParser``: Configuration file. - - .. codeauthor:: Rasmus Handberg - """ - - config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'config.ini') - if not os.path.isfile(config_file): - raise FileNotFoundError("config.ini file not found") - - config = configparser.ConfigParser() - config.read(config_file) - return config diff --git a/flows/coordinatematch/coordinatematch.py b/flows/coordinatematch/coordinatematch.py index 97a2acc..55f73eb 100644 --- a/flows/coordinatematch/coordinatematch.py +++ b/flows/coordinatematch/coordinatematch.py @@ -13,378 +13,331 @@ from networkx import Graph, connected_components from .wcs import WCS2 -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class CoordinateMatch(object): - def __init__(self, - xy, - rd, - xy_order=None, - rd_order=None, - xy_nmax=None, - rd_nmax=None, - n_triangle_packages=10, - triangle_package_size=10000, - maximum_angle_distance=0.001, - distance_factor=1): - - self.xy, self.rd = np.array(xy), np.array(rd) - - self._xy = xy - np.mean(xy, axis=0) - self._rd = rd - np.mean(rd, axis=0) - self._rd[:, 0] *= np.cos(np.deg2rad(self.rd[:, 1])) - - xy_n, rd_n = min(xy_nmax, len(xy)), min(rd_nmax, len(rd)) - - self.i_xy = xy_order[:xy_n] if xy_order is not None else np.arange( - xy_n) - self.i_rd = rd_order[:rd_n] if rd_order is not None else np.arange( - rd_n) - - self.n_triangle_packages = n_triangle_packages - self.triangle_package_size = triangle_package_size - - self.maximum_angle_distance = maximum_angle_distance - self.distance_factor = distance_factor - - self.triangle_package_generator = self._sorted_triangle_packages() - - self.i_xy_triangles = list() - self.i_rd_triangles = list() - self.parameters = None - self.neighbours = Graph() - - self.normalizations = type( - 'Normalizations', (object, ), - dict(ra=0.0001, dec=0.0001, scale=0.002, angle=0.002)) - - self.bounds = type( - 'Bounds', (object, ), - dict(xy=self.xy.mean(axis=0), - rd=None, - radius=None, - scale=None, - angle=None)) - - #---------------------------------------------------------------------------------------------- - def set_normalizations(self, ra=None, dec=None, scale=None, angle=None): - """ - Set normalization factors in the (ra, dec, scale, angle) space. - - Defaults are: - ra = 0.0001 degrees - dec = 0.0001 degrees - scale = 0.002 log(arcsec/pixel) - angle = 0.002 radians - """ - - if self.parameters is not None: - raise RuntimeError("can't change normalization after matching is started") - - # TODO: Dont use "assert" here - raise ValueError instead - assert ra is None or 0 < ra - assert dec is None or 0 < dec - assert scale is None or 0 < scale - assert angle is None or 0 < angle - - self.normalizations.ra = ra if ra is not None else self.normalizations.ra - self.normalizations.dec = dec if dec is not None else self.normalizations.dec - self.normalizations.scale = scale if scale is not None else self.normalizations.scale - self.normalizations.angle = angle if ra is not None else self.normalizations.angle - - #---------------------------------------------------------------------------------------------- - def set_bounds(self, x=None, y=None, ra=None, dec=None, radius=None, scale=None, angle=None): - """ - Set bounds for what are valid results. - - Set x, y, ra, dec and radius to specify that the x, y coordinates must be no - further that the radius [degrees] away from the ra, dec coordinates. - Set upper and lower bounds on the scale [log(arcsec/pixel)] and/or the angle - [radians] if those are known, possibly from previous observations with the - same system. - """ - - if self.parameters is not None: - raise RuntimeError("can't change bounds after matching is started") - - if [x, y, ra, dec, radius].count(None) == 5: - # TODO: Dont use "assert" here - raise ValueError instead - assert 0 <= ra < 360 - assert -180 <= dec <= 180 - assert 0 < radius - - self.bounds.xy = x, y - self.bounds.rd = ra, dec - self.bounds.radius = radius - - elif [x, y, ra, dec, radius].count(None) > 0: - raise ValueError('x, y, ra, dec and radius must all be specified') - - # TODO: Dont use "assert" here - raise ValueError instead - assert scale is None or 0 < scale[0] < scale[1] - assert angle is None or -np.pi <= angle[0] < angle[1] <= np.pi - - self.bounds.scale = scale if scale is not None else self.bounds.scale - self.bounds.angle = angle if angle is not None else self.bounds.angle + def __init__(self, xy, rd, xy_order=None, rd_order=None, xy_nmax=None, rd_nmax=None, n_triangle_packages=10, + triangle_package_size=10000, maximum_angle_distance=0.001, distance_factor=1): + + self.xy, self.rd = np.array(xy), np.array(rd) + + self._xy = xy - np.mean(xy, axis=0) + self._rd = rd - np.mean(rd, axis=0) + self._rd[:, 0] *= np.cos(np.deg2rad(self.rd[:, 1])) + + xy_n, rd_n = min(xy_nmax, len(xy)), min(rd_nmax, len(rd)) + + self.i_xy = xy_order[:xy_n] if xy_order is not None else np.arange(xy_n) + self.i_rd = rd_order[:rd_n] if rd_order is not None else np.arange(rd_n) + + self.n_triangle_packages = n_triangle_packages + self.triangle_package_size = triangle_package_size + + self.maximum_angle_distance = maximum_angle_distance + self.distance_factor = distance_factor + + self.triangle_package_generator = self._sorted_triangle_packages() + + self.i_xy_triangles = list() + self.i_rd_triangles = list() + self.parameters = None + self.neighbours = Graph() + + self.normalizations = type('Normalizations', (object,), dict(ra=0.0001, dec=0.0001, scale=0.002, angle=0.002)) + + self.bounds = type('Bounds', (object,), + dict(xy=self.xy.mean(axis=0), rd=None, radius=None, scale=None, angle=None)) + + # ---------------------------------------------------------------------------------------------- + def set_normalizations(self, ra=None, dec=None, scale=None, angle=None): + """ + Set normalization factors in the (ra, dec, scale, angle) space. + + Defaults are: + ra = 0.0001 degrees + dec = 0.0001 degrees + scale = 0.002 log(arcsec/pixel) + angle = 0.002 radians + """ + + if self.parameters is not None: + raise RuntimeError("can't change normalization after matching is started") + + # TODO: Dont use "assert" here - raise ValueError instead + assert ra is None or 0 < ra + assert dec is None or 0 < dec + assert scale is None or 0 < scale + assert angle is None or 0 < angle + + self.normalizations.ra = ra if ra is not None else self.normalizations.ra + self.normalizations.dec = dec if dec is not None else self.normalizations.dec + self.normalizations.scale = scale if scale is not None else self.normalizations.scale + self.normalizations.angle = angle if ra is not None else self.normalizations.angle + + # ---------------------------------------------------------------------------------------------- + def set_bounds(self, x=None, y=None, ra=None, dec=None, radius=None, scale=None, angle=None): + """ + Set bounds for what are valid results. + + Set x, y, ra, dec and radius to specify that the x, y coordinates must be no + further that the radius [degrees] away from the ra, dec coordinates. + Set upper and lower bounds on the scale [log(arcsec/pixel)] and/or the angle + [radians] if those are known, possibly from previous observations with the + same system. + """ + + if self.parameters is not None: + raise RuntimeError("can't change bounds after matching is started") + + if [x, y, ra, dec, radius].count(None) == 5: + # TODO: Dont use "assert" here - raise ValueError instead + assert 0 <= ra < 360 + assert -180 <= dec <= 180 + assert 0 < radius + + self.bounds.xy = x, y + self.bounds.rd = ra, dec + self.bounds.radius = radius + + elif [x, y, ra, dec, radius].count(None) > 0: + raise ValueError('x, y, ra, dec and radius must all be specified') + + # TODO: Dont use "assert" here - raise ValueError instead + assert scale is None or 0 < scale[0] < scale[1] + assert angle is None or -np.pi <= angle[0] < angle[1] <= np.pi + + self.bounds.scale = scale if scale is not None else self.bounds.scale + self.bounds.angle = angle if angle is not None else self.bounds.angle + + # ---------------------------------------------------------------------------------------------- + def _sorted_triangles(self, pool): + for i, c in enumerate(pool): + for i, b in enumerate(pool[:i]): + for a in pool[:i]: + yield a, b, c + + # ---------------------------------------------------------------------------------------------- + def _sorted_product_pairs(self, p, q): + i_p = np.argsort(np.arange(len(p))) + i_q = np.argsort(np.arange(len(q))) + for _i_p, _i_q in sorted(product(i_p, i_q), key=lambda idxs: sum(idxs)): + yield p[_i_p], q[_i_q] + + # ---------------------------------------------------------------------------------------------- + def _sorted_triangle_packages(self): + + i_xy_triangle_generator = self._sorted_triangles(self.i_xy) + i_rd_triangle_generator = self._sorted_triangles(self.i_rd) + + i_xy_triangle_slice_generator = (tuple(islice(i_xy_triangle_generator, self.triangle_package_size)) for _ in + count()) + i_rd_triangle_slice_generator = (list(islice(i_rd_triangle_generator, self.triangle_package_size)) for _ in + count()) + + for n in count(step=self.n_triangle_packages): + + i_xy_triangle_slice = tuple(filter(None, islice(i_xy_triangle_slice_generator, self.n_triangle_packages))) + i_rd_triangle_slice = tuple(filter(None, islice(i_rd_triangle_slice_generator, self.n_triangle_packages))) + + if not len(i_xy_triangle_slice) and not len(i_rd_triangle_slice): + return + + i_xy_triangle_generator2 = self._sorted_triangles(self.i_xy) + i_rd_triangle_generator2 = self._sorted_triangles(self.i_rd) + + i_xy_triangle_cum = filter(None, + (tuple(islice(i_xy_triangle_generator2, self.triangle_package_size)) for _ in + range(n))) + i_rd_triangle_cum = filter(None, + (tuple(islice(i_rd_triangle_generator2, self.triangle_package_size)) for _ in + range(n))) + + for i_xy_triangles, i_rd_triangles in chain(filter(None, chain(*zip_longest( # alternating chain + product(i_xy_triangle_slice, i_rd_triangle_cum), product(i_xy_triangle_cum, i_rd_triangle_slice)))), + self._sorted_product_pairs(i_xy_triangle_slice, + i_rd_triangle_slice)): + yield np.array(i_xy_triangles), np.array(i_rd_triangles) + + # ---------------------------------------------------------------------------------------------- + def _get_triangle_angles(self, triangles): - #---------------------------------------------------------------------------------------------- - def _sorted_triangles(self, pool): - for i, c in enumerate(pool): - for i, b in enumerate(pool[:i]): - for a in pool[:i]: + sidelengths = np.sqrt(np.power(triangles[:, (1, 0, 0)] - triangles[:, (2, 2, 1)], 2).sum(axis=2)) - yield a, b, c + # law of cosines + angles = np.power(sidelengths[:, ((1, 2), (0, 2), (0, 1))], 2).sum(axis=2) + angles -= np.power(sidelengths[:, (0, 1, 2)], 2) + angles /= 2 * sidelengths[:, ((1, 2), (0, 2), (0, 1))].prod(axis=2) - #---------------------------------------------------------------------------------------------- - def _sorted_product_pairs(self, p, q): - i_p = np.argsort(np.arange(len(p))) - i_q = np.argsort(np.arange(len(q))) - for _i_p, _i_q in sorted(product(i_p, i_q), key=lambda idxs: sum(idxs)): - yield p[_i_p], q[_i_q] - - #---------------------------------------------------------------------------------------------- - def _sorted_triangle_packages(self): - - i_xy_triangle_generator = self._sorted_triangles(self.i_xy) - i_rd_triangle_generator = self._sorted_triangles(self.i_rd) - - i_xy_triangle_slice_generator = (tuple( - islice(i_xy_triangle_generator, self.triangle_package_size)) for _ in count()) - i_rd_triangle_slice_generator = (list( - islice(i_rd_triangle_generator, self.triangle_package_size)) for _ in count()) + return np.arccos(angles) - for n in count(step=self.n_triangle_packages): - - i_xy_triangle_slice = tuple( - filter( - None, - islice(i_xy_triangle_slice_generator, - self.n_triangle_packages))) - i_rd_triangle_slice = tuple( - filter( - None, - islice(i_rd_triangle_slice_generator, - self.n_triangle_packages))) + # ---------------------------------------------------------------------------------------------- + def _solve_for_matrices(self, xy_triangles, rd_triangles): - if not len(i_xy_triangle_slice) and not len(i_rd_triangle_slice): - return + n = len(xy_triangles) - i_xy_triangle_generator2 = self._sorted_triangles(self.i_xy) - i_rd_triangle_generator2 = self._sorted_triangles(self.i_rd) - - i_xy_triangle_cum = filter(None, (tuple( - islice(i_xy_triangle_generator2, self.triangle_package_size)) for _ in range(n))) - i_rd_triangle_cum = filter(None, (tuple( - islice(i_rd_triangle_generator2, self.triangle_package_size)) for _ in range(n))) + A = xy_triangles - np.mean(xy_triangles, axis=1).reshape(n, 1, 2) + b = rd_triangles - np.mean(rd_triangles, axis=1).reshape(n, 1, 2) - for i_xy_triangles, i_rd_triangles in chain(filter(None, - chain(*zip_longest( # alternating chain - product(i_xy_triangle_slice, i_rd_triangle_cum), - product(i_xy_triangle_cum, i_rd_triangle_slice)))), - self._sorted_product_pairs(i_xy_triangle_slice, i_rd_triangle_slice)): - - yield np.array(i_xy_triangles), np.array(i_rd_triangles) - - #---------------------------------------------------------------------------------------------- - def _get_triangle_angles(self, triangles): - - sidelengths = np.sqrt( - np.power(triangles[:, (1, 0, 0)] - triangles[:, (2, 2, 1)], - 2).sum(axis=2)) + matrices = [np.linalg.lstsq(Ai, bi, rcond=None)[0].T for Ai, bi in zip(A, b)] - # law of cosines - angles = np.power(sidelengths[:, ((1, 2), (0, 2), (0, 1))], - 2).sum(axis=2) - angles -= np.power(sidelengths[:, (0, 1, 2)], 2) - angles /= 2 * sidelengths[:, ((1, 2), (0, 2), (0, 1))].prod(axis=2) + return np.array(matrices) - return np.arccos(angles) + # ---------------------------------------------------------------------------------------------- + def _extract_parameters(self, xy_triangles, rd_triangles, matrices): - #---------------------------------------------------------------------------------------------- - def _solve_for_matrices(self, xy_triangles, rd_triangles): + parameters = [] + for xy_com, rd_com, matrix in zip(xy_triangles.mean(axis=1), rd_triangles.mean(axis=1), matrices): + # com -> center-of-mass - n = len(xy_triangles) + cos_dec = np.cos(np.deg2rad(rd_com[1])) + coordinates = (self.bounds.xy - xy_com).dot(matrix) + coordinates = coordinates / np.array([cos_dec, 1]) + rd_com - A = xy_triangles - np.mean(xy_triangles, axis=1).reshape(n, 1, 2) - b = rd_triangles - np.mean(rd_triangles, axis=1).reshape(n, 1, 2) + wcs = WCS2.from_matrix(*xy_com, *rd_com, matrix) - matrices = [ - np.linalg.lstsq(Ai, bi, rcond=None)[0].T for Ai, bi in zip(A, b) - ] + parameters.append((*coordinates, np.log(wcs.scale), np.deg2rad(wcs.angle))) - return np.array(matrices) + return parameters - #---------------------------------------------------------------------------------------------- - def _extract_parameters(self, xy_triangles, rd_triangles, matrices): + # ---------------------------------------------------------------------------------------------- + def _get_bounds_mask(self, parameters): - parameters = [] - for xy_com, rd_com, matrix in zip(xy_triangles.mean(axis=1), rd_triangles.mean(axis=1), matrices): - # com -> center-of-mass + i = np.ones(len(parameters), dtype=bool) + parameters = np.array(parameters) - cos_dec = np.cos(np.deg2rad(rd_com[1])) - coordinates = (self.bounds.xy - xy_com).dot(matrix) - coordinates = coordinates / np.array([cos_dec, 1]) + rd_com + if self.bounds.radius is not None: + i *= angular_separation(*np.deg2rad(self.bounds.rd), + *zip(*np.deg2rad(parameters[:, (0, 1)]))) <= np.deg2rad(self.bounds.radius) - wcs = WCS2.from_matrix(*xy_com, *rd_com, matrix) + if self.bounds.scale is not None: + i *= self.bounds.scale[0] <= parameters[:, 2] + i *= parameters[:, 2] <= self.bounds.scale[1] - parameters.append((*coordinates, np.log(wcs.scale), np.deg2rad(wcs.angle))) + if self.bounds.angle is not None: + i *= self.bounds.angle[0] <= parameters[:, 3] + i *= parameters[:, 3] <= self.bounds.angle[1] - return parameters + return i - #---------------------------------------------------------------------------------------------- - def _get_bounds_mask(self, parameters): + # ---------------------------------------------------------------------------------------------- + def __call__(self, minimum_matches=4, ratio_superiority=1, timeout=60): + """ + Start the alogrithm. - i = np.ones(len(parameters), dtype=bool) - parameters = np.array(parameters) + Can be run multiple times with different arguments to relax the + restrictions. - if self.bounds.radius is not None: - i *= angular_separation( - *np.deg2rad(self.bounds.rd), - *zip(*np.deg2rad(parameters[:, (0, 1)])) - ) <= np.deg2rad(self.bounds.radius) + Example + -------- + cm = CoordinateMatch(xy, rd) - if self.bounds.scale is not None: - i *= self.bounds.scale[0] <= parameters[:, 2] - i *= parameters[:, 2] <= self.bounds.scale[1] + lkwargs = [{ + minimum_matches = 20, + ratio_superiority = 5, + timeout = 10 + },{ + timeout = 60 + } - if self.bounds.angle is not None: - i *= self.bounds.angle[0] <= parameters[:, 3] - i *= parameters[:, 3] <= self.bounds.angle[1] + for i, kwargs in enumerate(lkwargs): + try: + i_xy, i_rd = cm(**kwargs) + except TimeoutError: + continue + except StopIteration: + print('Failed, no more stars.') + else: + print('Success with kwargs[%d].' % i) + else: + print('Failed, timeout.') + """ - return i + self.parameters = list() if self.parameters is None else self.parameters - #---------------------------------------------------------------------------------------------- - def __call__(self, minimum_matches=4, ratio_superiority=1, timeout=60): - """ - Start the alogrithm. + t0 = time.time() - Can be run multiple times with different arguments to relax the - restrictions. + while time.time() - t0 < timeout: - Example - -------- - cm = CoordinateMatch(xy, rd) + # get triangles and derive angles - lkwargs = [{ - minimum_matches = 20, - ratio_superiority = 5, - timeout = 10 - },{ - timeout = 60 - } + i_xy_triangles, i_rd_triangles = next(self.triangle_package_generator) - for i, kwargs in enumerate(lkwargs): - try: - i_xy, i_rd = cm(**kwargs) - except TimeoutError: - continue - except StopIteration: - print('Failed, no more stars.') - else: - print('Success with kwargs[%d].' % i) - else: - print('Failed, timeout.') - """ + xy_angles = self._get_triangle_angles(self._xy[i_xy_triangles]) + rd_angles = self._get_triangle_angles(self._rd[i_rd_triangles]) - self.parameters = list() if self.parameters is None else self.parameters + # sort triangle vertices based on angles - t0 = time.time() + i = np.argsort(xy_angles, axis=1) + i_xy_triangles = np.take_along_axis(i_xy_triangles, i, axis=1) + xy_angles = np.take_along_axis(xy_angles, i, axis=1) - while time.time() - t0 < timeout: + i = np.argsort(rd_angles, axis=1) + i_rd_triangles = np.take_along_axis(i_rd_triangles, i, axis=1) + rd_angles = np.take_along_axis(rd_angles, i, axis=1) - # get triangles and derive angles + # match triangles + matches = KDTree(xy_angles).query_ball_tree(KDTree(rd_angles), r=self.maximum_angle_distance) + matches = np.array([(_i_xy, _i_rd) for _i_xy, _li_rd in enumerate(matches) for _i_rd in _li_rd]) - i_xy_triangles, i_rd_triangles = next( - self.triangle_package_generator) + if not len(matches): + continue - xy_angles = self._get_triangle_angles(self._xy[i_xy_triangles]) - rd_angles = self._get_triangle_angles(self._rd[i_rd_triangles]) + i_xy_triangles = list(i_xy_triangles[matches[:, 0]]) + i_rd_triangles = list(i_rd_triangles[matches[:, 1]]) - # sort triangle vertices based on angles + # get parameters of wcs solutions + matrices = self._solve_for_matrices(self._xy[np.array(i_xy_triangles)], self._rd[np.array(i_rd_triangles)]) - i = np.argsort(xy_angles, axis=1) - i_xy_triangles = np.take_along_axis(i_xy_triangles, i, axis=1) - xy_angles = np.take_along_axis(xy_angles, i, axis=1) + parameters = self._extract_parameters(self.xy[np.array(i_xy_triangles)], self.rd[np.array(i_rd_triangles)], + matrices) - i = np.argsort(rd_angles, axis=1) - i_rd_triangles = np.take_along_axis(i_rd_triangles, i, axis=1) - rd_angles = np.take_along_axis(rd_angles, i, axis=1) + # apply bounds if any + if any([self.bounds.radius, self.bounds.scale, self.bounds.angle]): + mask = self._get_bounds_mask(parameters) - # match triangles - matches = KDTree(xy_angles).query_ball_tree( - KDTree(rd_angles), r=self.maximum_angle_distance) - matches = np.array([(_i_xy, _i_rd) for _i_xy, _li_rd in enumerate(matches) for _i_rd in _li_rd]) + i_xy_triangles = np.array(i_xy_triangles)[mask].tolist() + i_rd_triangles = np.array(i_rd_triangles)[mask].tolist() + parameters = np.array(parameters)[mask].tolist() - if not len(matches): - continue + # normalize parameters + normalization = [getattr(self.normalizations, v) for v in ('ra', 'dec', 'scale', 'angle')] + normalization[0] *= np.cos(np.deg2rad(self.rd[:, 1].mean(axis=0))) + parameters = list(parameters / np.array(normalization)) - i_xy_triangles = list(i_xy_triangles[matches[:, 0]]) - i_rd_triangles = list(i_rd_triangles[matches[:, 1]]) + # match parameters + neighbours = KDTree(parameters).query_ball_tree(KDTree(self.parameters + parameters), + r=self.distance_factor) + neighbours = np.array([(i, j) for i, lj in enumerate(neighbours, len(self.parameters)) for j in lj]) + neighbours = list(neighbours[(np.diff(neighbours, axis=1) < 0).flatten()]) - # get parameters of wcs solutions - matrices = self._solve_for_matrices( - self._xy[np.array(i_xy_triangles)], - self._rd[np.array(i_rd_triangles)]) + if not len(neighbours): + continue - parameters = self._extract_parameters( - self.xy[np.array(i_xy_triangles)], - self.rd[np.array(i_rd_triangles)], matrices) - - # apply bounds if any - if any([self.bounds.radius, self.bounds.scale, self.bounds.angle]): - - mask = self._get_bounds_mask(parameters) - - i_xy_triangles = np.array(i_xy_triangles)[mask].tolist() - i_rd_triangles = np.array(i_rd_triangles)[mask].tolist() - parameters = np.array(parameters)[mask].tolist() - - # normalize parameters - normalization = [ - getattr(self.normalizations, v) - for v in ('ra', 'dec', 'scale', 'angle') - ] - normalization[0] *= np.cos(np.deg2rad(self.rd[:, 1].mean(axis=0))) - parameters = list(parameters / np.array(normalization)) + self.i_xy_triangles += i_xy_triangles + self.i_rd_triangles += i_rd_triangles + self.parameters += parameters + self.neighbours.add_edges_from(neighbours) - # match parameters - neighbours = KDTree(parameters).query_ball_tree( - KDTree(self.parameters + parameters), r=self.distance_factor) - neighbours = np.array([ - (i, j) for i, lj in enumerate(neighbours, len(self.parameters)) - for j in lj - ]) - neighbours = list( - neighbours[(np.diff(neighbours, axis=1) < 0).flatten()]) - - if not len(neighbours): - continue - - self.i_xy_triangles += i_xy_triangles - self.i_rd_triangles += i_rd_triangles - self.parameters += parameters - self.neighbours.add_edges_from(neighbours) - - # get largest neighborhood - communities = list(connected_components(self.neighbours)) - c1 = np.array(list(max(communities, key=len))) - i = np.unique(np.array(self.i_xy_triangles)[c1].flatten(), - return_index=True)[1] + # get largest neighborhood + communities = list(connected_components(self.neighbours)) + c1 = np.array(list(max(communities, key=len))) + i = np.unique(np.array(self.i_xy_triangles)[c1].flatten(), return_index=True)[1] - if ratio_superiority > 1 and len(communities) > 1: - communities.remove(set(c1)) - c2 = np.array(list(max(communities, key=len))) - _i = np.unique(np.array(self.i_xy_triangles)[c2].flatten()) + if ratio_superiority > 1 and len(communities) > 1: + communities.remove(set(c1)) + c2 = np.array(list(max(communities, key=len))) + _i = np.unique(np.array(self.i_xy_triangles)[c2].flatten()) - if len(i) / len(_i) < ratio_superiority: - continue + if len(i) / len(_i) < ratio_superiority: + continue - if len(i) >= minimum_matches: - break + if len(i) >= minimum_matches: + break - else: - raise TimeoutError + else: + raise TimeoutError - i_xy = np.array(self.i_xy_triangles)[c1].flatten()[i] - i_rd = np.array(self.i_rd_triangles)[c1].flatten()[i] + i_xy = np.array(self.i_xy_triangles)[c1].flatten()[i] + i_rd = np.array(self.i_rd_triangles)[c1].flatten()[i] - return list(zip(i_xy, i_rd)) + return list(zip(i_xy, i_rd)) diff --git a/flows/coordinatematch/wcs.py b/flows/coordinatematch/wcs.py index 02333d8..f851d72 100644 --- a/flows/coordinatematch/wcs.py +++ b/flows/coordinatematch/wcs.py @@ -11,217 +11,210 @@ from scipy.optimize import minimize from scipy.spatial.transform import Rotation -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class WCS2(): - """ - Manipulate WCS solution. + """ + Manipulate WCS solution. + + Initialize + ---------- + wcs = WCS2(x, y, ra, dec, scale, mirror, angle) + wcs = WCS2.from_matrix(x, y, ra, dec, matrix) + wcs = WCS2.from_points(list(zip(x, y)), list(zip(ra, dec))) + wcs = WCS2.from_astropy_wcs(astropy.wcs.WCS()) - Initialize - ---------- - wcs = WCS2(x, y, ra, dec, scale, mirror, angle) - wcs = WCS2.from_matrix(x, y, ra, dec, matrix) - wcs = WCS2.from_points(list(zip(x, y)), list(zip(ra, dec))) - wcs = WCS2.from_astropy_wcs(astropy.wcs.WCS()) + ra, dec and angle should be in degrees + scale should be in arcsec/pixel + matrix should be the PC or CD matrix - ra, dec and angle should be in degrees - scale should be in arcsec/pixel - matrix should be the PC or CD matrix + Examples + -------- + Adjust x, y offset: + wcs.x += delta_x + wcs.y += delta_y - Examples - -------- - Adjust x, y offset: - wcs.x += delta_x - wcs.y += delta_y + Get scale and angle: + print(wcs.scale, wcs.angle) - Get scale and angle: - print(wcs.scale, wcs.angle) + Change an astropy.wcs.WCS (wcs) angle + wcs = WCS2(wcs)(angle=new_angle).astropy_wcs - Change an astropy.wcs.WCS (wcs) angle - wcs = WCS2(wcs)(angle=new_angle).astropy_wcs + Adjust solution with points + wcs.adjust_with_points(list(zip(x, y)), list(zip(ra, dec))) + """ - Adjust solution with points - wcs.adjust_with_points(list(zip(x, y)), list(zip(ra, dec))) - """ + # ---------------------------------------------------------------------------------------------- + def __init__(self, x, y, ra, dec, scale, mirror, angle): + self.x, self.y = x, y + self.ra, self.dec = ra, dec + self.scale = scale + self.mirror = mirror + self.angle = angle - #---------------------------------------------------------------------------------------------- - def __init__(self, x, y, ra, dec, scale, mirror, angle): - self.x, self.y = x, y - self.ra, self.dec = ra, dec - self.scale = scale - self.mirror = mirror - self.angle = angle + # ---------------------------------------------------------------------------------------------- + @classmethod + def from_matrix(cls, x, y, ra, dec, matrix): + '''Initiate the class with a matrix.''' - #---------------------------------------------------------------------------------------------- - @classmethod - def from_matrix(cls, x, y, ra, dec, matrix): - '''Initiate the class with a matrix.''' + assert np.shape(matrix) == (2, 2), 'Matrix must be 2x2' - assert np.shape(matrix) == (2, 2), \ - 'Matrix must be 2x2' + scale, mirror, angle = cls._decompose_matrix(matrix) - scale, mirror, angle = cls._decompose_matrix(matrix) + return cls(x, y, ra, dec, scale, mirror, angle) - return cls(x, y, ra, dec, scale, mirror, angle) + # ---------------------------------------------------------------------------------------------- + @classmethod + def from_points(cls, xy, rd): + """Initiate the class with at least pixel + sky coordinates.""" - #---------------------------------------------------------------------------------------------- - @classmethod - def from_points(cls, xy, rd): - """Initiate the class with at least pixel + sky coordinates.""" + assert np.shape(xy) == np.shape(rd) == (len(xy), 2) and len( + xy) > 2, 'Arguments must be lists of at least 3 sets of coordinates' - assert np.shape(xy) == np.shape(rd) == (len(xy), 2) and len(xy) > 2, \ - 'Arguments must be lists of at least 3 sets of coordinates' + xy, rd = np.array(xy), np.array(rd) - xy, rd = np.array(xy), np.array(rd) + x, y, ra, dec, matrix = cls._solve_from_points(xy, rd) + scale, mirror, angle = cls._decompose_matrix(matrix) - x, y, ra, dec, matrix = cls._solve_from_points(xy, rd) - scale, mirror, angle = cls._decompose_matrix(matrix) + return cls(x, y, ra, dec, scale, mirror, angle) - return cls(x, y, ra, dec, scale, mirror, angle) + # ---------------------------------------------------------------------------------------------- + @classmethod + def from_astropy_wcs(cls, astropy_wcs): + """Initiate the class with an astropy.wcs.WCS object.""" - #---------------------------------------------------------------------------------------------- - @classmethod - def from_astropy_wcs(cls, astropy_wcs): - """Initiate the class with an astropy.wcs.WCS object.""" + if not isinstance(astropy_wcs, astropy.wcs.WCS): + raise ValueError('Must be astropy.wcs.WCS') - if not isinstance(astropy_wcs, astropy.wcs.WCS): - raise ValueError('Must be astropy.wcs.WCS') + (x, y), (ra, dec) = astropy_wcs.wcs.crpix, astropy_wcs.wcs.crval + scale, mirror, angle = cls._decompose_matrix(astropy_wcs.pixel_scale_matrix) - (x, y), (ra, dec) = astropy_wcs.wcs.crpix, astropy_wcs.wcs.crval - scale, mirror, angle = cls._decompose_matrix( - astropy_wcs.pixel_scale_matrix) + return cls(x, y, ra, dec, scale, mirror, angle) - return cls(x, y, ra, dec, scale, mirror, angle) + # ---------------------------------------------------------------------------------------------- + def adjust_with_points(self, xy, rd): + """ + Adjust the WCS with pixel + sky coordinates. - #---------------------------------------------------------------------------------------------- - def adjust_with_points(self, xy, rd): - """ - Adjust the WCS with pixel + sky coordinates. + If one set is given the change will be a simple offset. + If two sets are given the offset, angle and scale will be derived. + And if more sets are given a completely new solution will be found. + """ - If one set is given the change will be a simple offset. - If two sets are given the offset, angle and scale will be derived. - And if more sets are given a completely new solution will be found. - """ + assert np.shape(xy) == np.shape(rd) == (len(xy), 2), 'Arguments must be lists of sets of coordinates' - assert np.shape(xy) == np.shape(rd) == (len(xy), 2), \ - 'Arguments must be lists of sets of coordinates' + xy, rd = np.array(xy), np.array(rd) - xy, rd = np.array(xy), np.array(rd) + self.x, self.y = xy.mean(axis=0) + self.ra, self.dec = rd.mean(axis=0) - self.x, self.y = xy.mean(axis=0) - self.ra, self.dec = rd.mean(axis=0) + A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0) + b[:, 0] *= np.cos(np.deg2rad(rd[:, 1])) - A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0) - b[:, 0] *= np.cos(np.deg2rad(rd[:, 1])) + if len(xy) == 2: - if len(xy) == 2: + M = np.diag([[-1, 1][self.mirror], 1]) - M = np.diag([[-1, 1][self.mirror], 1]) + def R(t): + return np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]]) - def R(t): - return np.array([[np.cos(t), -np.sin(t)], - [np.sin(t), np.cos(t)]]) + def chi2(x): + return np.power(A.dot(x[1] / 60 / 60 * R(x[0]).dot(M).T) - b, 2).sum() - def chi2(x): - return np.power( - A.dot(x[1] / 60 / 60 * R(x[0]).dot(M).T) - b, 2).sum() - self.angle, self.scale = minimize(chi2, [self.angle, self.scale]).x + self.angle, self.scale = minimize(chi2, [self.angle, self.scale]).x - elif len(xy) > 2: - matrix = np.linalg.lstsq(A, b, rcond=None)[0].T - self.scale, self.mirror, self.angle = self._decompose_matrix( - matrix) + elif len(xy) > 2: + matrix = np.linalg.lstsq(A, b, rcond=None)[0].T + self.scale, self.mirror, self.angle = self._decompose_matrix(matrix) - #---------------------------------------------------------------------------------------------- - @property - def matrix(self): + # ---------------------------------------------------------------------------------------------- + @property + def matrix(self): - scale = self.scale / 60 / 60 - mirror = np.diag([[-1, 1][self.mirror], 1]) - angle = np.deg2rad(self.angle) + scale = self.scale / 60 / 60 + mirror = np.diag([[-1, 1][self.mirror], 1]) + angle = np.deg2rad(self.angle) - matrix = np.array([[np.cos(angle), -np.sin(angle)], - [np.sin(angle), np.cos(angle)]]) + matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) - return scale * matrix @ mirror + return scale * matrix @ mirror - #---------------------------------------------------------------------------------------------- - @property - def astropy_wcs(self): - wcs = astropy.wcs.WCS() - wcs.wcs.crpix = self.x, self.y - wcs.wcs.crval = self.ra, self.dec - wcs.wcs.pc = self.matrix - return wcs + # ---------------------------------------------------------------------------------------------- + @property + def astropy_wcs(self): + wcs = astropy.wcs.WCS() + wcs.wcs.crpix = self.x, self.y + wcs.wcs.crval = self.ra, self.dec + wcs.wcs.pc = self.matrix + return wcs - #---------------------------------------------------------------------------------------------- - @staticmethod - def _solve_from_points(xy, rd): + # ---------------------------------------------------------------------------------------------- + @staticmethod + def _solve_from_points(xy, rd): - (x, y), (ra, dec) = xy.mean(axis=0), rd.mean(axis=0) + (x, y), (ra, dec) = xy.mean(axis=0), rd.mean(axis=0) - A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0) - b[:, 0] *= np.cos(np.deg2rad(rd[:, 1])) + A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0) + b[:, 0] *= np.cos(np.deg2rad(rd[:, 1])) - matrix = np.linalg.lstsq(A, b, rcond=None)[0].T + matrix = np.linalg.lstsq(A, b, rcond=None)[0].T - return x, y, ra, dec, matrix + return x, y, ra, dec, matrix - #---------------------------------------------------------------------------------------------- - @staticmethod - def _decompose_matrix(matrix): + # ---------------------------------------------------------------------------------------------- + @staticmethod + def _decompose_matrix(matrix): - scale = np.sqrt(np.power(matrix, 2).sum() / 2) * 60 * 60 + scale = np.sqrt(np.power(matrix, 2).sum() / 2) * 60 * 60 - if np.argmax(np.power(matrix[0], 2)): - mirror = True if np.sign(matrix[0, 1]) != np.sign( - matrix[1, 0]) else False - else: - mirror = True if np.sign(matrix[0, 0]) == np.sign( - matrix[1, 1]) else False + if np.argmax(np.power(matrix[0], 2)): + mirror = True if np.sign(matrix[0, 1]) != np.sign(matrix[1, 0]) else False + else: + mirror = True if np.sign(matrix[0, 0]) == np.sign(matrix[1, 1]) else False - matrix = matrix if mirror else matrix.dot(np.diag([-1, 1])) + matrix = matrix if mirror else matrix.dot(np.diag([-1, 1])) - matrix3d = np.eye(3) - matrix3d[:2, :2] = matrix / (scale / 60 / 60) - angle = Rotation.from_matrix(matrix3d).as_euler('xyz', degrees=True)[2] + matrix3d = np.eye(3) + matrix3d[:2, :2] = matrix / (scale / 60 / 60) + angle = Rotation.from_matrix(matrix3d).as_euler('xyz', degrees=True)[2] - return scale, mirror, angle + return scale, mirror, angle - #---------------------------------------------------------------------------------------------- - def __setattr__(self, name, value): + # ---------------------------------------------------------------------------------------------- + def __setattr__(self, name, value): - if name == 'ra' and (value < 0 or value >= 360): - raise ValueError("0 <= R.A. < 360") + if name == 'ra' and (value < 0 or value >= 360): + raise ValueError("0 <= R.A. < 360") - elif name == 'dec' and (value < -180 or value > 180): - raise ValueError("-180 <= Dec. <= 180") + elif name == 'dec' and (value < -180 or value > 180): + raise ValueError("-180 <= Dec. <= 180") - elif name == 'scale' and value <= 0: - raise ValueError("Scale > 0") + elif name == 'scale' and value <= 0: + raise ValueError("Scale > 0") - elif name == 'mirror' and not isinstance(value, bool): - raise ValueError('mirror must be boolean') + elif name == 'mirror' and not isinstance(value, bool): + raise ValueError('mirror must be boolean') - elif name == 'angle' and (value <= -180 or value > 180): - raise ValueError("-180 < Angle <= 180") + elif name == 'angle' and (value <= -180 or value > 180): + raise ValueError("-180 < Angle <= 180") - super().__setattr__(name, value) + super().__setattr__(name, value) - #---------------------------------------------------------------------------------------------- - def __call__(self, **kwargs): - '''Make a copy with, or a copy with changes.''' + # ---------------------------------------------------------------------------------------------- + def __call__(self, **kwargs): + '''Make a copy with, or a copy with changes.''' - keys = ('x', 'y', 'ra', 'dec', 'scale', 'mirror', 'angle') + keys = ('x', 'y', 'ra', 'dec', 'scale', 'mirror', 'angle') - if not all(k in keys for k in kwargs): - raise ValueError('unknown argument(s)') + if not all(k in keys for k in kwargs): + raise ValueError('unknown argument(s)') - obj = deepcopy(self) - for k, v in kwargs.items(): - obj.__setattr__(k, v) - return obj + obj = deepcopy(self) + for k, v in kwargs.items(): + obj.__setattr__(k, v) + return obj - #---------------------------------------------------------------------------------------------- - def __repr__(self): - ra, dec = self.astropy_wcs.wcs_pix2world([(0, 0)], 0)[0] - return f'WCS2(0, 0, {ra:.4f}, {dec:.4f}, {self.scale:.2f}, {self.mirror}, {self.angle:.2f})' + # ---------------------------------------------------------------------------------------------- + def __repr__(self): + ra, dec = self.astropy_wcs.wcs_pix2world([(0, 0)], 0)[0] + return f'WCS2(0, 0, {ra:.4f}, {dec:.4f}, {self.scale:.2f}, {self.mirror}, {self.angle:.2f})' diff --git a/flows/epsfbuilder/epsfbuilder.py b/flows/epsfbuilder/epsfbuilder.py index 32c48d2..d8f8ebd 100644 --- a/flows/epsfbuilder/epsfbuilder.py +++ b/flows/epsfbuilder/epsfbuilder.py @@ -10,51 +10,44 @@ from scipy.interpolate import griddata import photutils.psf -class FlowsEPSFBuilder(photutils.psf.EPSFBuilder): - def _create_initial_epsf(self, stars): - - epsf = super()._create_initial_epsf(stars) - epsf.origin = None - X, Y = np.meshgrid(*map(np.arange, epsf.shape[::-1])) - - X = X / epsf.oversampling[0] - epsf.x_origin - Y = Y / epsf.oversampling[1] - epsf.y_origin +class FlowsEPSFBuilder(photutils.psf.EPSFBuilder): + def _create_initial_epsf(self, stars): + epsf = super()._create_initial_epsf(stars) + epsf.origin = None - self._epsf_xy_grid = X, Y + X, Y = np.meshgrid(*map(np.arange, epsf.shape[::-1])) - return epsf + X = X / epsf.oversampling[0] - epsf.x_origin + Y = Y / epsf.oversampling[1] - epsf.y_origin - def _resample_residual(self, star, epsf): + self._epsf_xy_grid = X, Y - #max_dist = .5 / np.sqrt(np.sum(np.power(epsf.oversampling, 2))) + return epsf - #star_points = list(zip(star._xidx_centered, star._yidx_centered)) - #epsf_points = list(zip(*map(np.ravel, self._epsf_xy_grid))) + def _resample_residual(self, star, epsf): + # max_dist = .5 / np.sqrt(np.sum(np.power(epsf.oversampling, 2))) - #star_tree = cKDTree(star_points) - #dd, ii = star_tree.query(epsf_points, distance_upper_bound=max_dist) - #mask = np.isfinite(dd) + # star_points = list(zip(star._xidx_centered, star._yidx_centered)) + # epsf_points = list(zip(*map(np.ravel, self._epsf_xy_grid))) - #star_data = np.full_like(epsf.data, np.nan) - #star_data.ravel()[mask] = star._data_values_normalized[ii[mask]] + # star_tree = cKDTree(star_points) + # dd, ii = star_tree.query(epsf_points, distance_upper_bound=max_dist) + # mask = np.isfinite(dd) - star_points = list(zip(star._xidx_centered, star._yidx_centered)) - star_data = griddata(star_points, star._data_values_normalized, - self._epsf_xy_grid) + # star_data = np.full_like(epsf.data, np.nan) + # star_data.ravel()[mask] = star._data_values_normalized[ii[mask]] - return star_data - epsf._data + star_points = list(zip(star._xidx_centered, star._yidx_centered)) + star_data = griddata(star_points, star._data_values_normalized, self._epsf_xy_grid) - def __call__(self, *args, **kwargs): + return star_data - epsf._data - t0 = time.time() + def __call__(self, *args, **kwargs): + t0 = time.time() - epsf, stars = super().__call__(*args, **kwargs) + epsf, stars = super().__call__(*args, **kwargs) - epsf.fit_info = dict( - n_iter=len(self._epsf), - max_iters=self.maxiters, - time=time.time() - t0, - ) + epsf.fit_info = dict(n_iter=len(self._epsf), max_iters=self.maxiters, time=time.time() - t0, ) - return epsf, stars + return epsf, stars diff --git a/flows/load_image.py b/flows/load_image.py index c08f250..b9e3a61 100644 --- a/flows/load_image.py +++ b/flows/load_image.py @@ -16,8 +16,8 @@ from astropy.time import Time from astropy.wcs import WCS, FITSFixedWarning from typing import Tuple +from tendrils import api -from flows import api from dataclasses import dataclass # , field import typing from abc import ABC, abstractmethod diff --git a/flows/photometry.py b/flows/photometry.py index 3c7378c..173bb6a 100644 --- a/flows/photometry.py +++ b/flows/photometry.py @@ -26,737 +26,651 @@ from astropy.wcs.utils import proj_plane_pixel_area, fit_wcs_from_points from astropy.time import Time import sep +from tendrils import api +from tendrils.utils import load_config warnings.simplefilter('ignore', category=AstropyDeprecationWarning) -from photutils import CircularAperture, CircularAnnulus, aperture_photometry # noqa: E402 -from photutils.psf import EPSFFitter, BasicPSFPhotometry, DAOGroup, extract_stars # noqa: E402 -from photutils import Background2D, SExtractorBackground, MedianBackground # noqa: E402 -from photutils.utils import calc_total_error # noqa: E402 - -from . import api # noqa: E402 -from . import reference_cleaning as refclean # noqa: E402 -from .config import load_config # noqa: E402 -from .plots import plt, plot_image # noqa: E402 -from .version import get_version # noqa: E402 -from .load_image import load_image # noqa: E402 -from .run_imagematch import run_imagematch # noqa: E402 -from .zeropoint import bootstrap_outlier, sigma_from_Chauvenet # noqa: E402 -from .coordinatematch import CoordinateMatch, WCS2 # noqa: E402 -from .epsfbuilder import FlowsEPSFBuilder # noqa: E402 +from photutils import CircularAperture, CircularAnnulus, aperture_photometry # noqa: E402 +from photutils.psf import EPSFFitter, BasicPSFPhotometry, DAOGroup, extract_stars # noqa: E402 +from photutils import Background2D, SExtractorBackground, MedianBackground # noqa: E402 +from photutils.utils import calc_total_error # noqa: E402 + +from . import reference_cleaning as refclean # noqa: E402 +from .plots import plt, plot_image # noqa: E402 +from .version import get_version # noqa: E402 +from .load_image import load_image # noqa: E402 +from .run_imagematch import run_imagematch # noqa: E402 +from .zeropoint import bootstrap_outlier, sigma_from_Chauvenet # noqa: E402 +from .coordinatematch import CoordinateMatch, WCS2 # noqa: E402 +from .epsfbuilder import FlowsEPSFBuilder # noqa: E402 __version__ = get_version(pep440=False) -#-------------------------------------------------------------------------------------------------- -def photometry(fileid, output_folder=None, attempt_imagematch=True, keep_diff_fixed=False, - cm_timeout=None): - """ - Run photometry. - - Parameters: - fileid (int): File ID to process. - output_folder (str, optional): Path to directory where output should be placed. - attempt_imagematch (bool, optional): If no subtracted image is available, but a - template image is, should we attempt to run ImageMatch using standard settings. - Default=True. - keep_diff_fixed (bool, optional): Allow psf photometry to recenter when - calculating the flux for the difference image. Setting to True can help if diff - image has non-source flux in the region around the SN. - cm_timeout (float, optional): Timeout in seconds for the :class:`CoordinateMatch` algorithm. - - .. codeauthor:: Rasmus Handberg - .. codeauthor:: Emir Karamehmetoglu - .. codeauthor:: Simon Holmbo - """ - - # Settings: - ref_target_dist_limit = 10 * u.arcsec # Reference star must be further than this away to be included - - logger = logging.getLogger(__name__) - tic = default_timer() - - # Use local copy of archive if configured to do so: - config = load_config() - - # Get datafile dict from API: - datafile = api.get_datafile(fileid) - logger.debug("Datafile: %s", datafile) - targetid = datafile['targetid'] - target_name = datafile['target_name'] - photfilter = datafile['photfilter'] - - archive_local = config.get('photometry', 'archive_local', fallback=None) - if archive_local is not None: - datafile['archive_path'] = archive_local - if not os.path.isdir(datafile['archive_path']): - raise FileNotFoundError("ARCHIVE is not available: " + datafile['archive_path']) - - # Get the catalog containing the target and reference stars: - # TODO: Include proper-motion to the time of observation - catalog = api.get_catalog(targetid, output='table') - target = catalog['target'][0] - target_coord = coords.SkyCoord( - ra=target['ra'], - dec=target['decl'], - unit='deg', - frame='icrs') - - # Folder to save output: - if output_folder is None: - output_folder_root = config.get('photometry', 'output', fallback='.') - output_folder = os.path.join(output_folder_root, target_name, f'{fileid:05d}') - logger.info("Placing output in '%s'", output_folder) - os.makedirs(output_folder, exist_ok=True) - - # The paths to the science image: - filepath = os.path.join(datafile['archive_path'], datafile['path']) - - # TODO: Download datafile using API to local drive: - # TODO: Is this a security concern? - # if archive_local: - # api.download_datafile(datafile, archive_local) - - # Translate photometric filter into table column: - ref_filter = { - 'up': 'u_mag', - 'gp': 'g_mag', - 'rp': 'r_mag', - 'ip': 'i_mag', - 'zp': 'z_mag', - 'B': 'B_mag', - 'V': 'V_mag', - 'J': 'J_mag', - 'H': 'H_mag', - 'K': 'K_mag', - }.get(photfilter, None) - - if ref_filter is None: - logger.warning("Could not find filter '%s' in catalogs. Using default gp filter.", photfilter) - ref_filter = 'g_mag' - - # Load the image from the FITS file: - logger.info("Load image '%s'", filepath) - image = load_image(filepath, target_coord=target_coord) - - references = catalog['references'] - references.sort(ref_filter) - - # Check that there actually are reference stars in that filter: - if allnan(references[ref_filter]): - raise ValueError("No reference stars found in current photfilter.") - - #============================================================================================== - # BARYCENTRIC CORRECTION OF TIME - #============================================================================================== - - ltt_bary = image.obstime.light_travel_time(target_coord, ephemeris='jpl') - image.obstime = image.obstime.tdb + ltt_bary - - #============================================================================================== - # BACKGROUND ESTIMATION - #============================================================================================== - - fig, ax = plt.subplots(1, 2, figsize=(20, 18)) - plot_image(image.clean, ax=ax[0], scale='log', cbar='right', title='Image') - plot_image(image.mask, ax=ax[1], scale='linear', cbar='right', title='Mask') - fig.savefig(os.path.join(output_folder, 'original.png'), bbox_inches='tight') - plt.close(fig) - - # Estimate image background: - # Not using image.clean here, since we are redefining the mask anyway - background = Background2D(image.clean, (128, 128), - filter_size=(5, 5), - sigma_clip=SigmaClip(sigma=3.0), - bkg_estimator=SExtractorBackground(), - exclude_percentile=50.0) - - # Create background-subtracted image: - image.subclean = image.clean - background.background - - # Plot background estimation: - fig, ax = plt.subplots(1, 3, figsize=(20, 6)) - plot_image(image.clean, ax=ax[0], scale='log', title='Original') - plot_image(background.background, ax=ax[1], scale='log', title='Background') - plot_image(image.subclean, ax=ax[2], scale='log', title='Background subtracted') - fig.savefig(os.path.join(output_folder, 'background.png'), bbox_inches='tight') - plt.close(fig) - - # TODO: Is this correct?! - image.error = calc_total_error(image.clean, background.background_rms, 1.0) - - # Use sep to for soure extraction - sep_background = sep.Background(image.image, mask=image.mask) - objects = sep.extract(image.image - sep_background, - thresh=5., - err=sep_background.globalrms, - mask=image.mask, - deblend_cont=0.1, - minarea=9, - clean_param=2.0) - - # Cleanup large arrays which are no longer needed: - del background, fig, ax, sep_background, ltt_bary - gc.collect() - - #============================================================================================== - # DETECTION OF STARS AND MATCHING WITH CATALOG - #============================================================================================== - - # Account for proper motion: - replace(references['pm_ra'], np.NaN, 0) - replace(references['pm_dec'], np.NaN, 0) - refs_coord = coords.SkyCoord( - ra=references['ra'], - dec=references['decl'], - pm_ra_cosdec=references['pm_ra'], - pm_dec=references['pm_dec'], - unit='deg', - frame='icrs', - obstime=Time(2015.5, format='decimalyear')) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ErfaWarning) - refs_coord = refs_coord.apply_space_motion(new_obstime=image.obstime) - - # TODO: These need to be based on the instrument! - radius = 10 - fwhm_guess = 6.0 - fwhm_min = 3.5 - fwhm_max = 18.0 - - # Clean extracted stars - masked_sep_xy, sep_mask, masked_sep_rsqs = refclean.force_reject_g2d( - objects['x'], - objects['y'], - image, - get_fwhm=False, - radius=radius, - fwhm_guess=fwhm_guess, - rsq_min=0.3, - fwhm_max=fwhm_max, - fwhm_min=fwhm_min) - - logger.info("Finding new WCS solution...") - head_wcs = str(WCS2.from_astropy_wcs(image.wcs)) - logger.debug('Head WCS: %s', head_wcs) - - # Solve for new WCS - cm = CoordinateMatch( - xy=list(masked_sep_xy[sep_mask]), - rd=list(zip(refs_coord.ra.deg, refs_coord.dec.deg)), - xy_order=np.argsort( - np.power(masked_sep_xy[sep_mask] - np.array(image.shape[::-1]) / 2, - 2).sum(axis=1)), - rd_order=np.argsort(target_coord.separation(refs_coord)), - xy_nmax=100, - rd_nmax=100, - maximum_angle_distance=0.002) - - # Set timeout par to infinity unless specified. - if cm_timeout is None: - cm_timeout = float('inf') - try: - i_xy, i_rd = map(np.array, zip(*cm(5, 1.5, timeout=cm_timeout))) - except TimeoutError: - logger.warning('TimeoutError: No new WCS solution found') - except StopIteration: - logger.warning('StopIterationError: No new WCS solution found') - else: - logger.info('Found new WCS') - image.wcs = fit_wcs_from_points( - np.array(list(zip(*cm.xy[i_xy]))), - coords.SkyCoord(*map(list, zip(*cm.rd[i_rd])), unit='deg')) - del i_xy, i_rd - - used_wcs = str(WCS2.from_astropy_wcs(image.wcs)) - logger.debug('Used WCS: %s', used_wcs) - - # Calculate pixel-coordinates of references: - xy = image.wcs.all_world2pix(list(zip(refs_coord.ra.deg, refs_coord.dec.deg)), 0) - references['pixel_column'], references['pixel_row'] = x, y = list(map(np.array, zip(*xy))) - - # Clean out the references: - hsize = 10 - clean_references = references[ - (target_coord.separation(refs_coord) > ref_target_dist_limit) - & (x > hsize) & (x < (image.shape[1] - 1 - hsize)) - & (y > hsize) & (y < (image.shape[0] - 1 - hsize))] - - if not clean_references: - raise RuntimeError('No clean references in field') - - # Calculate the targets position in the image: - target_pixel_pos = image.wcs.all_world2pix([(target['ra'], target['decl'])], 0)[0] - - # Clean reference star locations - masked_fwhms, masked_ref_xys, rsq_mask, masked_rsqs = refclean.force_reject_g2d( - clean_references['pixel_column'], - clean_references['pixel_row'], - image, - get_fwhm=True, - radius=radius, - fwhm_guess=fwhm_guess, - fwhm_max=fwhm_max, - fwhm_min=fwhm_min, - rsq_min=0.15) - - # Use R^2 to more robustly determine initial FWHM guess. - # This cleaning is good when we have FEW references. - fwhm, fwhm_clean_references = refclean.clean_with_rsq_and_get_fwhm( - masked_fwhms, - masked_rsqs, - clean_references, - min_fwhm_references=2, - min_references=6, - rsq_min=0.15) - logger.info('Initial FWHM guess is %f pixels', fwhm) - - # Create plot of target and reference star positions from 2D Gaussian fits. - fig, ax = plt.subplots(1, 1, figsize=(20, 18)) - plot_image(image.subclean, ax=ax, scale='log', cbar='right', title=target_name) - ax.scatter(fwhm_clean_references['pixel_column'], fwhm_clean_references['pixel_row'], c='r', marker='o', alpha=0.3) - ax.scatter(masked_sep_xy[:, 0], masked_sep_xy[:, 1], marker='s', alpha=1.0, edgecolors='green', facecolors='none') - ax.scatter(target_pixel_pos[0], target_pixel_pos[1], marker='+', s=20, c='r') - fig.savefig(os.path.join(output_folder, 'positions_g2d.png'), bbox_inches='tight') - plt.close(fig) - - # Final clean of wcs corrected references - logger.info("Number of references before final cleaning: %d", len(clean_references)) - logger.debug('Masked R^2 values: %s', masked_rsqs[rsq_mask]) - references = refclean.get_clean_references(clean_references, masked_rsqs, rsq_ideal=0.8) - logger.info("Number of references after final cleaning: %d", len(references)) - - # Create plot of target and reference star positions: - fig, ax = plt.subplots(1, 1, figsize=(20, 18)) - plot_image(image.subclean, ax=ax, scale='log', cbar='right', title=target_name) - ax.scatter(references['pixel_column'], references['pixel_row'], c='r', marker='o', alpha=0.6) - ax.scatter(masked_sep_xy[:, 0], masked_sep_xy[:, 1], marker='s', alpha=0.6, edgecolors='green', facecolors='none') - ax.scatter(target_pixel_pos[0], target_pixel_pos[1], marker='+', s=20, c='r') - fig.savefig(os.path.join(output_folder, 'positions.png'), bbox_inches='tight') - plt.close(fig) - - # Cleanup large arrays which are no longer needed: - del fig, ax, cm - gc.collect() - - #============================================================================================== - # CREATE EFFECTIVE PSF MODEL - #============================================================================================== - - # Make cutouts of stars using extract_stars: - # Scales with FWHM - size = int(np.round(29 * fwhm / 6)) - size += 0 if size % 2 else 1 # Make sure it's a uneven number - size = max(size, 15) # Never go below 15 pixels - - # Extract stars sub-images: - xy = [tuple(masked_ref_xys[clean_references['starid'] == ref['starid']].data[0]) for ref in references] - with warnings.catch_warnings(): - warnings.simplefilter('ignore', AstropyUserWarning) - stars = extract_stars( - NDData(data=image.subclean.data, mask=image.mask), - Table(np.array(xy), names=('x', 'y')), - size=size + 6 # +6 for edge buffer - ) - - logger.info("Number of stars input to ePSF builder: %d", len(stars)) - - # Plot the stars being used for ePSF: - imgnr = 0 - nrows, ncols = 5, 5 - for k in range(int(np.ceil(len(stars) / (nrows * ncols)))): - fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), squeeze=True) - ax = ax.ravel() - for i in range(nrows * ncols): - if imgnr > len(stars) - 1: - ax[i].axis('off') - else: - plot_image(stars[imgnr], ax=ax[i], scale='log', cmap='viridis') # FIXME (no x-ticks) - imgnr += 1 - - fig.savefig(os.path.join(output_folder, f'epsf_stars{k+1:02d}.png'), bbox_inches='tight') - plt.close(fig) - - # Build the ePSF: - epsf, stars = FlowsEPSFBuilder( - oversampling=1, - shape=1 * size, - fitter=EPSFFitter(fit_boxsize=max(int(np.round(1.5 * fwhm)), 5)), - recentering_boxsize=max(int(np.round(2 * fwhm)), 5), - norm_radius=max(fwhm, 5), - maxiters=100, - progress_bar=logger.isEnabledFor(logging.INFO) - )(stars) - logger.info('Built PSF model (%(n_iter)d/%(max_iters)d) in %(time).1f seconds', epsf.fit_info) - - # Store which stars were used in ePSF in the table: - references['used_for_epsf'] = False - references['used_for_epsf'][[star.id_label - 1 for star in stars.all_good_stars]] = True - logger.info("Number of stars used for ePSF: %d", np.sum(references['used_for_epsf'])) - - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 15)) - plot_image(epsf.data, ax=ax1, cmap='viridis') - - fwhms = [] - bad_epsf_detected = False - for a, ax in ((0, ax3), (1, ax2)): - # Collapse the PDF along this axis: - profile = epsf.data.sum(axis=a) - itop = profile.argmax() - poffset = profile[itop] / 2 - - # Run a spline through the points, but subtract half of the peak value, and find the roots: - # We have to use a cubic spline, since roots() is not supported for other splines - # for some reason - profile_intp = UnivariateSpline(np.arange(0, len(profile)), profile - poffset, - k=3, s=0, ext=3) - lr = profile_intp.roots() - - # Plot the profile and spline: - x_fine = np.linspace(-0.5, len(profile) - 0.5, 500) - ax.plot(profile, 'k.-') - ax.plot(x_fine, profile_intp(x_fine) + poffset, 'g-') - ax.axvline(itop) - ax.set_xlim(-0.5, len(profile) - 0.5) - - # Do some sanity checks on the ePSF: - # It should pass 50% exactly twice and have the maximum inside that region. - # I.e. it should be a single gaussian-like peak - if len(lr) != 2 or itop < lr[0] or itop > lr[1]: - logger.error("Bad PSF along axis %d", a) - bad_epsf_detected = True - else: - axis_fwhm = lr[1] - lr[0] - fwhms.append(axis_fwhm) - ax.axvspan(lr[0], lr[1], facecolor='g', alpha=0.2) - - # Save the ePSF figure: - ax4.axis('off') - fig.savefig(os.path.join(output_folder, 'epsf.png'), bbox_inches='tight') - plt.close(fig) - - # There was a problem with the ePSF: - if bad_epsf_detected: - raise RuntimeError("Bad ePSF detected.") - - # Let's make the final FWHM the largest one we found: - fwhm = np.max(fwhms) - logger.info("Final FWHM based on ePSF: %f", fwhm) - - # Cleanup large arrays which are no longer needed: - del fig, ax, stars, fwhms, profile_intp - gc.collect() - - #============================================================================================== - # COORDINATES TO DO PHOTOMETRY AT - #============================================================================================== - - coordinates = np.array([[ref['pixel_column'], ref['pixel_row']] for ref in references]) - - # Add the main target position as the first entry for doing photometry directly in the - # science image: - coordinates = np.concatenate(([target_pixel_pos], coordinates), axis=0) - - #============================================================================================== - # APERTURE PHOTOMETRY - #============================================================================================== - - # Define apertures for aperture photometry: - apertures = CircularAperture(coordinates, r=fwhm) - annuli = CircularAnnulus(coordinates, r_in=1.5*fwhm, r_out=2.5*fwhm) - - apphot_tbl = aperture_photometry(image.subclean, [apertures, annuli], - mask=image.mask, error=image.error) - - logger.info('Aperture Photometry Success') - logger.debug("Aperture Photometry Table:\n%s", apphot_tbl) - - #============================================================================================== - # PSF PHOTOMETRY - #============================================================================================== - - # Create photometry object: - photometry_obj = BasicPSFPhotometry( - group_maker=DAOGroup(fwhm), - bkg_estimator=MedianBackground(), - psf_model=epsf, - fitter=fitting.LevMarLSQFitter(), - fitshape=size, - aperture_radius=fwhm) - - psfphot_tbl = photometry_obj(image=image.subclean, - init_guesses=Table(coordinates, names=['x_0', 'y_0'])) - - logger.info('PSF Photometry Success') - logger.debug("PSF Photometry Table:\n%s", psfphot_tbl) - - #============================================================================================== - # TEMPLATE SUBTRACTION AND TARGET PHOTOMETRY - #============================================================================================== - - # Find the pixel-scale of the science image: - pixel_area = proj_plane_pixel_area(image.wcs.celestial) - pixel_scale = np.sqrt(pixel_area) * 3600 # arcsec/pixel - # print(image.wcs.celestial.cunit) % Doesn't work? - logger.info("Science image pixel scale: %f", pixel_scale) - - diffimage = None - if datafile.get('diffimg') is not None: - diffimg_path = os.path.join(datafile['archive_path'], datafile['diffimg']['path']) - diffimg = load_image(diffimg_path) - diffimage = diffimg.image - - elif attempt_imagematch and datafile.get('template') is not None: - # Run the template subtraction, and get back - # the science image where the template has been subtracted: - diffimage = run_imagematch(datafile, target, - star_coord=coordinates, - fwhm=fwhm, - pixel_scale=pixel_scale) - - # We have a diff image, so let's do photometry of the target using this: - if diffimage is not None: - # Include mask from original image: - diffimage = np.ma.masked_array(diffimage, image.mask) - - # Create apertures around the target: - apertures = CircularAperture(target_pixel_pos, r=fwhm) - annuli = CircularAnnulus(target_pixel_pos, r_in=1.5*fwhm, r_out=2.5*fwhm) - - # Create two plots of the difference image: - fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(20, 20)) - plot_image(diffimage, ax=ax, cbar='right', title=target_name) - ax.plot(target_pixel_pos[0], target_pixel_pos[1], marker='+', markersize=20, color='r') - fig.savefig(os.path.join(output_folder, 'diffimg.png'), bbox_inches='tight') - apertures.plot(axes=ax, color='r', lw=2) - annuli.plot(axes=ax, color='r', lw=2) - ax.set_xlim(target_pixel_pos[0] - 50, target_pixel_pos[0] + 50) - ax.set_ylim(target_pixel_pos[1] - 50, target_pixel_pos[1] + 50) - fig.savefig(os.path.join(output_folder, 'diffimg_zoom.png'), bbox_inches='tight') - plt.close(fig) - - # Run aperture photometry on subtracted image: - target_apphot_tbl = aperture_photometry(diffimage, [apertures, annuli], - mask=image.mask, - error=image.error) - - # Make target only photometry object if keep_diff_fixed = True - if keep_diff_fixed: - epsf.fixed.update({'x_0': True, 'y_0': True}) - - # TODO: Try iteraratively subtracted photometry - # Create photometry object: - photometry_obj = BasicPSFPhotometry( - group_maker=DAOGroup(0.0001), - bkg_estimator=MedianBackground(), - psf_model=epsf, - fitter=fitting.LevMarLSQFitter(), - fitshape=size, - aperture_radius=fwhm) - - # Run PSF photometry on template subtracted image: - target_psfphot_tbl = photometry_obj(diffimage, - init_guesses=Table(target_pixel_pos, names=['x_0', 'y_0'])) - - # Need to adjust table columns if x_0 and y_0 were fixed - if keep_diff_fixed: - target_psfphot_tbl['x_0_unc'] = 0.0 - target_psfphot_tbl['y_0_unc'] = 0.0 - - # Combine the output tables from the target and the reference stars into one: - apphot_tbl = vstack([target_apphot_tbl, apphot_tbl], join_type='exact') - psfphot_tbl = vstack([target_psfphot_tbl, psfphot_tbl], join_type='exact') - - # Build results table: - tab = references.copy() - - row = { - 'starid': 0, - 'ra': target['ra'], - 'decl': target['decl'], - 'pixel_column': target_pixel_pos[0], - 'pixel_row': target_pixel_pos[1], - 'used_for_epsf': False - } - row.update([(k, np.NaN) for k in set(tab.keys()) - set(row) - {'gaia_variability'}]) - tab.insert_row(0, row) - - if diffimage is not None: - row['starid'] = -1 - tab.insert_row(0, row) - - indx_main_target = tab['starid'] <= 0 - - # Subtract background estimated from annuli: - flux_aperture = apphot_tbl['aperture_sum_0'] - (apphot_tbl['aperture_sum_1'] / annuli.area) * apertures.area - flux_aperture_error = np.sqrt(apphot_tbl['aperture_sum_err_0']**2 + (apphot_tbl['aperture_sum_err_1'] / annuli.area * apertures.area)**2) - - # Add table columns with results: - tab['flux_aperture'] = flux_aperture / image.exptime - tab['flux_aperture_error'] = flux_aperture_error / image.exptime - tab['flux_psf'] = psfphot_tbl['flux_fit'] / image.exptime - tab['flux_psf_error'] = psfphot_tbl['flux_unc'] / image.exptime - tab['pixel_column_psf_fit'] = psfphot_tbl['x_fit'] - tab['pixel_row_psf_fit'] = psfphot_tbl['y_fit'] - tab['pixel_column_psf_fit_error'] = psfphot_tbl['x_0_unc'] - tab['pixel_row_psf_fit_error'] = psfphot_tbl['y_0_unc'] - - # Check that we got valid photometry: - if np.any(~np.isfinite(tab[indx_main_target]['flux_psf'])) or np.any(~np.isfinite(tab[indx_main_target]['flux_psf_error'])): - raise RuntimeError("Target magnitude is undefined.") - - #============================================================================================== - # CALIBRATE - #============================================================================================== - - # Convert PSF fluxes to magnitudes: - mag_inst = -2.5 * np.log10(tab['flux_psf']) - mag_inst_err = (2.5 / np.log(10)) * (tab['flux_psf_error'] / tab['flux_psf']) - - # Corresponding magnitudes in catalog: - mag_catalog = tab[ref_filter] - - # Mask out things that should not be used in calibration: - use_for_calibration = np.ones_like(mag_catalog, dtype='bool') - use_for_calibration[indx_main_target] = False # Do not use target for calibration - use_for_calibration[~np.isfinite(mag_inst) | ~np.isfinite(mag_catalog)] = False - - # Just creating some short-hands: - x = mag_catalog[use_for_calibration] - y = mag_inst[use_for_calibration] - yerr = mag_inst_err[use_for_calibration] - weights = 1.0 / yerr**2 - - if not any(use_for_calibration): - raise RuntimeError("No calibration stars") - - # Fit linear function with fixed slope, using sigma-clipping: - model = models.Linear1D(slope=1, fixed={'slope': True}) - fitter = fitting.FittingWithOutlierRemoval( - fitting.LinearLSQFitter(), - sigma_clip, - sigma=3.0) - best_fit, sigma_clipped = fitter(model, x, y, weights=weights) - - # Extract zero-point and estimate its error using a single weighted fit: - # I don't know why there is not an error-estimate attached directly to the Parameter? - zp = -1 * best_fit.intercept.value # Negative, because that is the way zeropoints are usually defined - - weights[sigma_clipped] = 0 # Trick to make following expression simpler - n_weights = len(weights.nonzero()[0]) - if n_weights > 1: - zp_error = np.sqrt(n_weights * nansum(weights * (y - best_fit(x))**2) / nansum(weights) / (n_weights - 1)) - else: - zp_error = np.NaN - logger.info('Leastsquare ZP = %.3f, ZP_error = %.3f', zp, zp_error) - - # Determine sigma clipping sigma according to Chauvenet method - # But don't allow less than sigma = sigmamin, setting to 1.5 for now. - # Should maybe be 2? - sigmamin = 1.5 - sig_chauv = sigma_from_Chauvenet(len(x)) - sig_chauv = sig_chauv if sig_chauv >= sigmamin else sigmamin - - # Extract zero point and error using bootstrap method - nboot = 1000 - logger.info('Running bootstrap with sigma = %.2f and n = %d', sig_chauv, nboot) - pars = bootstrap_outlier(x, y, yerr, - n=nboot, - model=model, - fitter=fitting.LinearLSQFitter, - outlier=sigma_clip, - outlier_kwargs={'sigma': sig_chauv}, - summary='median', - error='bootstrap', - return_vals=False) - - zp_bs = pars['intercept'] * -1.0 - zp_error_bs = pars['intercept_error'] - - logger.info('Bootstrapped ZP = %.3f, ZP_error = %.3f', zp_bs, zp_error_bs) - - # Check that difference is not large - zp_diff = 0.4 - if np.abs(zp_bs - zp) >= zp_diff: - logger.warning("Bootstrap and weighted LSQ ZPs differ by %.2f, \ + +# -------------------------------------------------------------------------------------------------- +def photometry(fileid, output_folder=None, attempt_imagematch=True, keep_diff_fixed=False, cm_timeout=None): + """ + Run photometry. + + Parameters: + fileid (int): File ID to process. + output_folder (str, optional): Path to directory where output should be placed. + attempt_imagematch (bool, optional): If no subtracted image is available, but a + template image is, should we attempt to run ImageMatch using standard settings. + Default=True. + keep_diff_fixed (bool, optional): Allow psf photometry to recenter when + calculating the flux for the difference image. Setting to True can help if diff + image has non-source flux in the region around the SN. + cm_timeout (float, optional): Timeout in seconds for the :class:`CoordinateMatch` algorithm. + + .. codeauthor:: Rasmus Handberg + .. codeauthor:: Emir Karamehmetoglu + .. codeauthor:: Simon Holmbo + """ + + # Settings: + ref_target_dist_limit = 10 * u.arcsec # Reference star must be further than this away to be included + + logger = logging.getLogger(__name__) + tic = default_timer() + + # Use local copy of archive if configured to do so: + config = load_config() + + # Get datafile dict from API: + datafile = api.get_datafile(fileid) + logger.debug("Datafile: %s", datafile) + targetid = datafile['targetid'] + target_name = datafile['target_name'] + photfilter = datafile['photfilter'] + + archive_local = config.get('photometry', 'archive_local', fallback=None) + if archive_local is not None: + datafile['archive_path'] = archive_local + if not os.path.isdir(datafile['archive_path']): + raise FileNotFoundError("ARCHIVE is not available: " + datafile['archive_path']) + + # Get the catalog containing the target and reference stars: + # TODO: Include proper-motion to the time of observation + catalog = api.get_catalog(targetid, output='table') + target = catalog['target'][0] + target_coord = coords.SkyCoord(ra=target['ra'], dec=target['decl'], unit='deg', frame='icrs') + + # Folder to save output: + if output_folder is None: + output_folder_root = config.get('photometry', 'output', fallback='.') + output_folder = os.path.join(output_folder_root, target_name, f'{fileid:05d}') + logger.info("Placing output in '%s'", output_folder) + os.makedirs(output_folder, exist_ok=True) + + # The paths to the science image: + filepath = os.path.join(datafile['archive_path'], datafile['path']) + + # TODO: Download datafile using API to local drive: + # TODO: Is this a security concern? + # if archive_local: + # api.download_datafile(datafile, archive_local) + + # Translate photometric filter into table column: + ref_filter = {'up': 'u_mag', 'gp': 'g_mag', 'rp': 'r_mag', 'ip': 'i_mag', 'zp': 'z_mag', 'B': 'B_mag', 'V': 'V_mag', + 'J': 'J_mag', 'H': 'H_mag', 'K': 'K_mag', }.get(photfilter, None) + + if ref_filter is None: + logger.warning("Could not find filter '%s' in catalogs. Using default gp filter.", photfilter) + ref_filter = 'g_mag' + + # Load the image from the FITS file: + logger.info("Load image '%s'", filepath) + image = load_image(filepath, target_coord=target_coord) + + references = catalog['references'] + references.sort(ref_filter) + + # Check that there actually are reference stars in that filter: + if allnan(references[ref_filter]): + raise ValueError("No reference stars found in current photfilter.") + + # ============================================================================================== + # BARYCENTRIC CORRECTION OF TIME + # ============================================================================================== + + ltt_bary = image.obstime.light_travel_time(target_coord, ephemeris='jpl') + image.obstime = image.obstime.tdb + ltt_bary + + # ============================================================================================== + # BACKGROUND ESTIMATION + # ============================================================================================== + + fig, ax = plt.subplots(1, 2, figsize=(20, 18)) + plot_image(image.clean, ax=ax[0], scale='log', cbar='right', title='Image') + plot_image(image.mask, ax=ax[1], scale='linear', cbar='right', title='Mask') + fig.savefig(os.path.join(output_folder, 'original.png'), bbox_inches='tight') + plt.close(fig) + + # Estimate image background: + # Not using image.clean here, since we are redefining the mask anyway + background = Background2D(image.clean, (128, 128), filter_size=(5, 5), sigma_clip=SigmaClip(sigma=3.0), + bkg_estimator=SExtractorBackground(), exclude_percentile=50.0) + + # Create background-subtracted image: + image.subclean = image.clean - background.background + + # Plot background estimation: + fig, ax = plt.subplots(1, 3, figsize=(20, 6)) + plot_image(image.clean, ax=ax[0], scale='log', title='Original') + plot_image(background.background, ax=ax[1], scale='log', title='Background') + plot_image(image.subclean, ax=ax[2], scale='log', title='Background subtracted') + fig.savefig(os.path.join(output_folder, 'background.png'), bbox_inches='tight') + plt.close(fig) + + # TODO: Is this correct?! + image.error = calc_total_error(image.clean, background.background_rms, 1.0) + + # Use sep to for soure extraction + sep_background = sep.Background(image.image, mask=image.mask) + objects = sep.extract(image.image - sep_background, thresh=5., err=sep_background.globalrms, mask=image.mask, + deblend_cont=0.1, minarea=9, clean_param=2.0) + + # Cleanup large arrays which are no longer needed: + del background, fig, ax, sep_background, ltt_bary + gc.collect() + + # ============================================================================================== + # DETECTION OF STARS AND MATCHING WITH CATALOG + # ============================================================================================== + + # Account for proper motion: + replace(references['pm_ra'], np.NaN, 0) + replace(references['pm_dec'], np.NaN, 0) + refs_coord = coords.SkyCoord(ra=references['ra'], dec=references['decl'], pm_ra_cosdec=references['pm_ra'], + pm_dec=references['pm_dec'], unit='deg', frame='icrs', + obstime=Time(2015.5, format='decimalyear')) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ErfaWarning) + refs_coord = refs_coord.apply_space_motion(new_obstime=image.obstime) + + # TODO: These need to be based on the instrument! + radius = 10 + fwhm_guess = 6.0 + fwhm_min = 3.5 + fwhm_max = 18.0 + + # Clean extracted stars + masked_sep_xy, sep_mask, masked_sep_rsqs = refclean.force_reject_g2d(objects['x'], objects['y'], image, + get_fwhm=False, radius=radius, + fwhm_guess=fwhm_guess, rsq_min=0.3, + fwhm_max=fwhm_max, fwhm_min=fwhm_min) + + logger.info("Finding new WCS solution...") + head_wcs = str(WCS2.from_astropy_wcs(image.wcs)) + logger.debug('Head WCS: %s', head_wcs) + + # Solve for new WCS + cm = CoordinateMatch(xy=list(masked_sep_xy[sep_mask]), rd=list(zip(refs_coord.ra.deg, refs_coord.dec.deg)), + xy_order=np.argsort( + np.power(masked_sep_xy[sep_mask] - np.array(image.shape[::-1]) / 2, 2).sum(axis=1)), + rd_order=np.argsort(target_coord.separation(refs_coord)), xy_nmax=100, rd_nmax=100, + maximum_angle_distance=0.002) + + # Set timeout par to infinity unless specified. + if cm_timeout is None: + cm_timeout = float('inf') + try: + i_xy, i_rd = map(np.array, zip(*cm(5, 1.5, timeout=cm_timeout))) + except TimeoutError: + logger.warning('TimeoutError: No new WCS solution found') + except StopIteration: + logger.warning('StopIterationError: No new WCS solution found') + else: + logger.info('Found new WCS') + image.wcs = fit_wcs_from_points(np.array(list(zip(*cm.xy[i_xy]))), + coords.SkyCoord(*map(list, zip(*cm.rd[i_rd])), unit='deg')) + del i_xy, i_rd + + used_wcs = str(WCS2.from_astropy_wcs(image.wcs)) + logger.debug('Used WCS: %s', used_wcs) + + # Calculate pixel-coordinates of references: + xy = image.wcs.all_world2pix(list(zip(refs_coord.ra.deg, refs_coord.dec.deg)), 0) + references['pixel_column'], references['pixel_row'] = x, y = list(map(np.array, zip(*xy))) + + # Clean out the references: + hsize = 10 + clean_references = references[(target_coord.separation(refs_coord) > ref_target_dist_limit) & (x > hsize) & ( + x < (image.shape[1] - 1 - hsize)) & (y > hsize) & (y < (image.shape[0] - 1 - hsize))] + + if not clean_references: + raise RuntimeError('No clean references in field') + + # Calculate the targets position in the image: + target_pixel_pos = image.wcs.all_world2pix([(target['ra'], target['decl'])], 0)[0] + + # Clean reference star locations + masked_fwhms, masked_ref_xys, rsq_mask, masked_rsqs = refclean.force_reject_g2d(clean_references['pixel_column'], + clean_references['pixel_row'], + image, get_fwhm=True, radius=radius, + fwhm_guess=fwhm_guess, + fwhm_max=fwhm_max, + fwhm_min=fwhm_min, rsq_min=0.15) + + # Use R^2 to more robustly determine initial FWHM guess. + # This cleaning is good when we have FEW references. + fwhm, fwhm_clean_references = refclean.clean_with_rsq_and_get_fwhm(masked_fwhms, masked_rsqs, clean_references, + min_fwhm_references=2, min_references=6, + rsq_min=0.15) + logger.info('Initial FWHM guess is %f pixels', fwhm) + + # Create plot of target and reference star positions from 2D Gaussian fits. + fig, ax = plt.subplots(1, 1, figsize=(20, 18)) + plot_image(image.subclean, ax=ax, scale='log', cbar='right', title=target_name) + ax.scatter(fwhm_clean_references['pixel_column'], fwhm_clean_references['pixel_row'], c='r', marker='o', alpha=0.3) + ax.scatter(masked_sep_xy[:, 0], masked_sep_xy[:, 1], marker='s', alpha=1.0, edgecolors='green', facecolors='none') + ax.scatter(target_pixel_pos[0], target_pixel_pos[1], marker='+', s=20, c='r') + fig.savefig(os.path.join(output_folder, 'positions_g2d.png'), bbox_inches='tight') + plt.close(fig) + + # Final clean of wcs corrected references + logger.info("Number of references before final cleaning: %d", len(clean_references)) + logger.debug('Masked R^2 values: %s', masked_rsqs[rsq_mask]) + references = refclean.get_clean_references(clean_references, masked_rsqs, rsq_ideal=0.8) + logger.info("Number of references after final cleaning: %d", len(references)) + + # Create plot of target and reference star positions: + fig, ax = plt.subplots(1, 1, figsize=(20, 18)) + plot_image(image.subclean, ax=ax, scale='log', cbar='right', title=target_name) + ax.scatter(references['pixel_column'], references['pixel_row'], c='r', marker='o', alpha=0.6) + ax.scatter(masked_sep_xy[:, 0], masked_sep_xy[:, 1], marker='s', alpha=0.6, edgecolors='green', facecolors='none') + ax.scatter(target_pixel_pos[0], target_pixel_pos[1], marker='+', s=20, c='r') + fig.savefig(os.path.join(output_folder, 'positions.png'), bbox_inches='tight') + plt.close(fig) + + # Cleanup large arrays which are no longer needed: + del fig, ax, cm + gc.collect() + + # ============================================================================================== + # CREATE EFFECTIVE PSF MODEL + # ============================================================================================== + + # Make cutouts of stars using extract_stars: + # Scales with FWHM + size = int(np.round(29 * fwhm / 6)) + size += 0 if size % 2 else 1 # Make sure it's a uneven number + size = max(size, 15) # Never go below 15 pixels + + # Extract stars sub-images: + xy = [tuple(masked_ref_xys[clean_references['starid'] == ref['starid']].data[0]) for ref in references] + with warnings.catch_warnings(): + warnings.simplefilter('ignore', AstropyUserWarning) + stars = extract_stars(NDData(data=image.subclean.data, mask=image.mask), Table(np.array(xy), names=('x', 'y')), + size=size + 6 # +6 for edge buffer + ) + + logger.info("Number of stars input to ePSF builder: %d", len(stars)) + + # Plot the stars being used for ePSF: + imgnr = 0 + nrows, ncols = 5, 5 + for k in range(int(np.ceil(len(stars) / (nrows * ncols)))): + fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), squeeze=True) + ax = ax.ravel() + for i in range(nrows * ncols): + if imgnr > len(stars) - 1: + ax[i].axis('off') + else: + plot_image(stars[imgnr], ax=ax[i], scale='log', cmap='viridis') # FIXME (no x-ticks) + imgnr += 1 + + fig.savefig(os.path.join(output_folder, f'epsf_stars{k + 1:02d}.png'), bbox_inches='tight') + plt.close(fig) + + # Build the ePSF: + epsf, stars = FlowsEPSFBuilder(oversampling=1, shape=1 * size, + fitter=EPSFFitter(fit_boxsize=max(int(np.round(1.5 * fwhm)), 5)), + recentering_boxsize=max(int(np.round(2 * fwhm)), 5), norm_radius=max(fwhm, 5), + maxiters=100, progress_bar=logger.isEnabledFor(logging.INFO))(stars) + logger.info('Built PSF model (%(n_iter)d/%(max_iters)d) in %(time).1f seconds', epsf.fit_info) + + # Store which stars were used in ePSF in the table: + references['used_for_epsf'] = False + references['used_for_epsf'][[star.id_label - 1 for star in stars.all_good_stars]] = True + logger.info("Number of stars used for ePSF: %d", np.sum(references['used_for_epsf'])) + + fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 15)) + plot_image(epsf.data, ax=ax1, cmap='viridis') + + fwhms = [] + bad_epsf_detected = False + for a, ax in ((0, ax3), (1, ax2)): + # Collapse the PDF along this axis: + profile = epsf.data.sum(axis=a) + itop = profile.argmax() + poffset = profile[itop] / 2 + + # Run a spline through the points, but subtract half of the peak value, and find the roots: + # We have to use a cubic spline, since roots() is not supported for other splines + # for some reason + profile_intp = UnivariateSpline(np.arange(0, len(profile)), profile - poffset, k=3, s=0, ext=3) + lr = profile_intp.roots() + + # Plot the profile and spline: + x_fine = np.linspace(-0.5, len(profile) - 0.5, 500) + ax.plot(profile, 'k.-') + ax.plot(x_fine, profile_intp(x_fine) + poffset, 'g-') + ax.axvline(itop) + ax.set_xlim(-0.5, len(profile) - 0.5) + + # Do some sanity checks on the ePSF: + # It should pass 50% exactly twice and have the maximum inside that region. + # I.e. it should be a single gaussian-like peak + if len(lr) != 2 or itop < lr[0] or itop > lr[1]: + logger.error("Bad PSF along axis %d", a) + bad_epsf_detected = True + else: + axis_fwhm = lr[1] - lr[0] + fwhms.append(axis_fwhm) + ax.axvspan(lr[0], lr[1], facecolor='g', alpha=0.2) + + # Save the ePSF figure: + ax4.axis('off') + fig.savefig(os.path.join(output_folder, 'epsf.png'), bbox_inches='tight') + plt.close(fig) + + # There was a problem with the ePSF: + if bad_epsf_detected: + raise RuntimeError("Bad ePSF detected.") + + # Let's make the final FWHM the largest one we found: + fwhm = np.max(fwhms) + logger.info("Final FWHM based on ePSF: %f", fwhm) + + # Cleanup large arrays which are no longer needed: + del fig, ax, stars, fwhms, profile_intp + gc.collect() + + # ============================================================================================== + # COORDINATES TO DO PHOTOMETRY AT + # ============================================================================================== + + coordinates = np.array([[ref['pixel_column'], ref['pixel_row']] for ref in references]) + + # Add the main target position as the first entry for doing photometry directly in the + # science image: + coordinates = np.concatenate(([target_pixel_pos], coordinates), axis=0) + + # ============================================================================================== + # APERTURE PHOTOMETRY + # ============================================================================================== + + # Define apertures for aperture photometry: + apertures = CircularAperture(coordinates, r=fwhm) + annuli = CircularAnnulus(coordinates, r_in=1.5 * fwhm, r_out=2.5 * fwhm) + + apphot_tbl = aperture_photometry(image.subclean, [apertures, annuli], mask=image.mask, error=image.error) + + logger.info('Aperture Photometry Success') + logger.debug("Aperture Photometry Table:\n%s", apphot_tbl) + + # ============================================================================================== + # PSF PHOTOMETRY + # ============================================================================================== + + # Create photometry object: + photometry_obj = BasicPSFPhotometry(group_maker=DAOGroup(fwhm), bkg_estimator=MedianBackground(), psf_model=epsf, + fitter=fitting.LevMarLSQFitter(), fitshape=size, aperture_radius=fwhm) + + psfphot_tbl = photometry_obj(image=image.subclean, init_guesses=Table(coordinates, names=['x_0', 'y_0'])) + + logger.info('PSF Photometry Success') + logger.debug("PSF Photometry Table:\n%s", psfphot_tbl) + + # ============================================================================================== + # TEMPLATE SUBTRACTION AND TARGET PHOTOMETRY + # ============================================================================================== + + # Find the pixel-scale of the science image: + pixel_area = proj_plane_pixel_area(image.wcs.celestial) + pixel_scale = np.sqrt(pixel_area) * 3600 # arcsec/pixel + # print(image.wcs.celestial.cunit) % Doesn't work? + logger.info("Science image pixel scale: %f", pixel_scale) + + diffimage = None + if datafile.get('diffimg') is not None: + diffimg_path = os.path.join(datafile['archive_path'], datafile['diffimg']['path']) + diffimg = load_image(diffimg_path) + diffimage = diffimg.image + + elif attempt_imagematch and datafile.get('template') is not None: + # Run the template subtraction, and get back + # the science image where the template has been subtracted: + diffimage = run_imagematch(datafile, target, star_coord=coordinates, fwhm=fwhm, pixel_scale=pixel_scale) + + # We have a diff image, so let's do photometry of the target using this: + if diffimage is not None: + # Include mask from original image: + diffimage = np.ma.masked_array(diffimage, image.mask) + + # Create apertures around the target: + apertures = CircularAperture(target_pixel_pos, r=fwhm) + annuli = CircularAnnulus(target_pixel_pos, r_in=1.5 * fwhm, r_out=2.5 * fwhm) + + # Create two plots of the difference image: + fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(20, 20)) + plot_image(diffimage, ax=ax, cbar='right', title=target_name) + ax.plot(target_pixel_pos[0], target_pixel_pos[1], marker='+', markersize=20, color='r') + fig.savefig(os.path.join(output_folder, 'diffimg.png'), bbox_inches='tight') + apertures.plot(axes=ax, color='r', lw=2) + annuli.plot(axes=ax, color='r', lw=2) + ax.set_xlim(target_pixel_pos[0] - 50, target_pixel_pos[0] + 50) + ax.set_ylim(target_pixel_pos[1] - 50, target_pixel_pos[1] + 50) + fig.savefig(os.path.join(output_folder, 'diffimg_zoom.png'), bbox_inches='tight') + plt.close(fig) + + # Run aperture photometry on subtracted image: + target_apphot_tbl = aperture_photometry(diffimage, [apertures, annuli], mask=image.mask, error=image.error) + + # Make target only photometry object if keep_diff_fixed = True + if keep_diff_fixed: + epsf.fixed.update({'x_0': True, 'y_0': True}) + + # TODO: Try iteraratively subtracted photometry + # Create photometry object: + photometry_obj = BasicPSFPhotometry(group_maker=DAOGroup(0.0001), bkg_estimator=MedianBackground(), + psf_model=epsf, fitter=fitting.LevMarLSQFitter(), fitshape=size, + aperture_radius=fwhm) + + # Run PSF photometry on template subtracted image: + target_psfphot_tbl = photometry_obj(diffimage, init_guesses=Table(target_pixel_pos, names=['x_0', 'y_0'])) + + # Need to adjust table columns if x_0 and y_0 were fixed + if keep_diff_fixed: + target_psfphot_tbl['x_0_unc'] = 0.0 + target_psfphot_tbl['y_0_unc'] = 0.0 + + # Combine the output tables from the target and the reference stars into one: + apphot_tbl = vstack([target_apphot_tbl, apphot_tbl], join_type='exact') + psfphot_tbl = vstack([target_psfphot_tbl, psfphot_tbl], join_type='exact') + + # Build results table: + tab = references.copy() + + row = {'starid': 0, 'ra': target['ra'], 'decl': target['decl'], 'pixel_column': target_pixel_pos[0], + 'pixel_row': target_pixel_pos[1], 'used_for_epsf': False} + row.update([(k, np.NaN) for k in set(tab.keys()) - set(row) - {'gaia_variability'}]) + tab.insert_row(0, row) + + if diffimage is not None: + row['starid'] = -1 + tab.insert_row(0, row) + + indx_main_target = tab['starid'] <= 0 + + # Subtract background estimated from annuli: + flux_aperture = apphot_tbl['aperture_sum_0'] - (apphot_tbl['aperture_sum_1'] / annuli.area) * apertures.area + flux_aperture_error = np.sqrt( + apphot_tbl['aperture_sum_err_0'] ** 2 + (apphot_tbl['aperture_sum_err_1'] / annuli.area * apertures.area) ** 2) + + # Add table columns with results: + tab['flux_aperture'] = flux_aperture / image.exptime + tab['flux_aperture_error'] = flux_aperture_error / image.exptime + tab['flux_psf'] = psfphot_tbl['flux_fit'] / image.exptime + tab['flux_psf_error'] = psfphot_tbl['flux_unc'] / image.exptime + tab['pixel_column_psf_fit'] = psfphot_tbl['x_fit'] + tab['pixel_row_psf_fit'] = psfphot_tbl['y_fit'] + tab['pixel_column_psf_fit_error'] = psfphot_tbl['x_0_unc'] + tab['pixel_row_psf_fit_error'] = psfphot_tbl['y_0_unc'] + + # Check that we got valid photometry: + if np.any(~np.isfinite(tab[indx_main_target]['flux_psf'])) or np.any( + ~np.isfinite(tab[indx_main_target]['flux_psf_error'])): + raise RuntimeError("Target magnitude is undefined.") + + # ============================================================================================== + # CALIBRATE + # ============================================================================================== + + # Convert PSF fluxes to magnitudes: + mag_inst = -2.5 * np.log10(tab['flux_psf']) + mag_inst_err = (2.5 / np.log(10)) * (tab['flux_psf_error'] / tab['flux_psf']) + + # Corresponding magnitudes in catalog: + mag_catalog = tab[ref_filter] + + # Mask out things that should not be used in calibration: + use_for_calibration = np.ones_like(mag_catalog, dtype='bool') + use_for_calibration[indx_main_target] = False # Do not use target for calibration + use_for_calibration[~np.isfinite(mag_inst) | ~np.isfinite(mag_catalog)] = False + + # Just creating some short-hands: + x = mag_catalog[use_for_calibration] + y = mag_inst[use_for_calibration] + yerr = mag_inst_err[use_for_calibration] + weights = 1.0 / yerr ** 2 + + if not any(use_for_calibration): + raise RuntimeError("No calibration stars") + + # Fit linear function with fixed slope, using sigma-clipping: + model = models.Linear1D(slope=1, fixed={'slope': True}) + fitter = fitting.FittingWithOutlierRemoval(fitting.LinearLSQFitter(), sigma_clip, sigma=3.0) + best_fit, sigma_clipped = fitter(model, x, y, weights=weights) + + # Extract zero-point and estimate its error using a single weighted fit: + # I don't know why there is not an error-estimate attached directly to the Parameter? + zp = -1 * best_fit.intercept.value # Negative, because that is the way zeropoints are usually defined + + weights[sigma_clipped] = 0 # Trick to make following expression simpler + n_weights = len(weights.nonzero()[0]) + if n_weights > 1: + zp_error = np.sqrt(n_weights * nansum(weights * (y - best_fit(x)) ** 2) / nansum(weights) / (n_weights - 1)) + else: + zp_error = np.NaN + logger.info('Leastsquare ZP = %.3f, ZP_error = %.3f', zp, zp_error) + + # Determine sigma clipping sigma according to Chauvenet method + # But don't allow less than sigma = sigmamin, setting to 1.5 for now. + # Should maybe be 2? + sigmamin = 1.5 + sig_chauv = sigma_from_Chauvenet(len(x)) + sig_chauv = sig_chauv if sig_chauv >= sigmamin else sigmamin + + # Extract zero point and error using bootstrap method + nboot = 1000 + logger.info('Running bootstrap with sigma = %.2f and n = %d', sig_chauv, nboot) + pars = bootstrap_outlier(x, y, yerr, n=nboot, model=model, fitter=fitting.LinearLSQFitter, outlier=sigma_clip, + outlier_kwargs={'sigma': sig_chauv}, summary='median', error='bootstrap', + return_vals=False) + + zp_bs = pars['intercept'] * -1.0 + zp_error_bs = pars['intercept_error'] + + logger.info('Bootstrapped ZP = %.3f, ZP_error = %.3f', zp_bs, zp_error_bs) + + # Check that difference is not large + zp_diff = 0.4 + if np.abs(zp_bs - zp) >= zp_diff: + logger.warning("Bootstrap and weighted LSQ ZPs differ by %.2f, \ which is more than the allowed %.2f mag.", np.abs(zp_bs - zp), zp_diff) - # Add calibrated magnitudes to the photometry table: - tab['mag'] = mag_inst + zp_bs - tab['mag_error'] = np.sqrt(mag_inst_err**2 + zp_error_bs**2) - - fig, ax = plt.subplots(1, 1) - ax.errorbar(x, y, yerr=yerr, fmt='k.') - ax.scatter(x[sigma_clipped], y[sigma_clipped], marker='x', c='r') - ax.plot(x, best_fit(x), color='g', linewidth=3) - ax.set_xlabel('Catalog magnitude') - ax.set_ylabel('Instrumental magnitude') - fig.savefig(os.path.join(output_folder, 'calibration.png'), bbox_inches='tight') - plt.close(fig) - - # Check that we got valid photometry: - if not np.isfinite(tab[0]['mag']) or not np.isfinite(tab[0]['mag_error']): - raise RuntimeError("Target magnitude is undefined.") - - #============================================================================================== - # SAVE PHOTOMETRY - #============================================================================================== - - # Descriptions of columns: - tab['used_for_epsf'].description = 'Was object used for building ePSF?' - tab['mag'].description = 'Measured magnitude' - tab['mag'].unit = u.mag - tab['mag_error'].description = 'Error on measured magnitude' - tab['mag_error'].unit = u.mag - tab['flux_aperture'].description = 'Measured flux using aperture photometry' - tab['flux_aperture'].unit = u.count / u.second - tab['flux_aperture_error'].description = 'Error on measured flux using aperture photometry' - tab['flux_aperture_error'].unit = u.count / u.second - tab['flux_psf'].description = 'Measured flux using PSF photometry' - tab['flux_psf'].unit = u.count / u.second - tab['flux_psf_error'].description = 'Error on measured flux using PSF photometry' - tab['flux_psf_error'].unit = u.count / u.second - tab['pixel_column'].description = 'Location on image pixel columns' - tab['pixel_column'].unit = u.pixel - tab['pixel_row'].description = 'Location on image pixel rows' - tab['pixel_row'].unit = u.pixel - tab['pixel_column_psf_fit'].description = 'Measured location on image pixel columns from PSF photometry' - tab['pixel_column_psf_fit'].unit = u.pixel - tab['pixel_column_psf_fit_error'].description = 'Error on measured location on image pixel columns from PSF photometry' - tab['pixel_column_psf_fit_error'].unit = u.pixel - tab['pixel_row_psf_fit'].description = 'Measured location on image pixel rows from PSF photometry' - tab['pixel_row_psf_fit'].unit = u.pixel - tab['pixel_row_psf_fit_error'].description = 'Error on measured location on image pixel rows from PSF photometry' - tab['pixel_row_psf_fit_error'].unit = u.pixel - - # Meta-data: - tab.meta['fileid'] = fileid - tab.meta['target_name'] = target_name - tab.meta['version'] = __version__ - tab.meta['template'] = None if datafile.get('template') is None else datafile['template']['fileid'] - tab.meta['diffimg'] = None if datafile.get('diffimg') is None else datafile['diffimg']['fileid'] - tab.meta['photfilter'] = photfilter - tab.meta['fwhm'] = fwhm * u.pixel - tab.meta['pixel_scale'] = pixel_scale * u.arcsec / u.pixel - tab.meta['seeing'] = (fwhm * pixel_scale) * u.arcsec - tab.meta['obstime-bmjd'] = float(image.obstime.mjd) - tab.meta['zp'] = zp_bs - tab.meta['zp_error'] = zp_error_bs - tab.meta['zp_diff'] = np.abs(zp_bs - zp) - tab.meta['zp_error_weights'] = zp_error - tab.meta['head_wcs'] = head_wcs # TODO: Are these really useful? - tab.meta['used_wcs'] = used_wcs # TODO: Are these really useful? - - # Filepath where to save photometry: - photometry_output = os.path.join(output_folder, 'photometry.ecsv') - - # Write the final table to file: - tab.write(photometry_output, format='ascii.ecsv', delimiter=',', overwrite=True) - - toc = default_timer() - - logger.info("------------------------------------------------------") - logger.info("Success!") - logger.info("Main target: %f +/- %f", tab[0]['mag'], tab[0]['mag_error']) - logger.info("Photometry took: %.1f seconds", toc - tic) - - return photometry_output + # Add calibrated magnitudes to the photometry table: + tab['mag'] = mag_inst + zp_bs + tab['mag_error'] = np.sqrt(mag_inst_err ** 2 + zp_error_bs ** 2) + + fig, ax = plt.subplots(1, 1) + ax.errorbar(x, y, yerr=yerr, fmt='k.') + ax.scatter(x[sigma_clipped], y[sigma_clipped], marker='x', c='r') + ax.plot(x, best_fit(x), color='g', linewidth=3) + ax.set_xlabel('Catalog magnitude') + ax.set_ylabel('Instrumental magnitude') + fig.savefig(os.path.join(output_folder, 'calibration.png'), bbox_inches='tight') + plt.close(fig) + + # Check that we got valid photometry: + if not np.isfinite(tab[0]['mag']) or not np.isfinite(tab[0]['mag_error']): + raise RuntimeError("Target magnitude is undefined.") + + # ============================================================================================== + # SAVE PHOTOMETRY + # ============================================================================================== + + # Descriptions of columns: + tab['used_for_epsf'].description = 'Was object used for building ePSF?' + tab['mag'].description = 'Measured magnitude' + tab['mag'].unit = u.mag + tab['mag_error'].description = 'Error on measured magnitude' + tab['mag_error'].unit = u.mag + tab['flux_aperture'].description = 'Measured flux using aperture photometry' + tab['flux_aperture'].unit = u.count / u.second + tab['flux_aperture_error'].description = 'Error on measured flux using aperture photometry' + tab['flux_aperture_error'].unit = u.count / u.second + tab['flux_psf'].description = 'Measured flux using PSF photometry' + tab['flux_psf'].unit = u.count / u.second + tab['flux_psf_error'].description = 'Error on measured flux using PSF photometry' + tab['flux_psf_error'].unit = u.count / u.second + tab['pixel_column'].description = 'Location on image pixel columns' + tab['pixel_column'].unit = u.pixel + tab['pixel_row'].description = 'Location on image pixel rows' + tab['pixel_row'].unit = u.pixel + tab['pixel_column_psf_fit'].description = 'Measured location on image pixel columns from PSF photometry' + tab['pixel_column_psf_fit'].unit = u.pixel + tab[ + 'pixel_column_psf_fit_error'].description = 'Error on measured location on image pixel columns from PSF photometry' + tab['pixel_column_psf_fit_error'].unit = u.pixel + tab['pixel_row_psf_fit'].description = 'Measured location on image pixel rows from PSF photometry' + tab['pixel_row_psf_fit'].unit = u.pixel + tab['pixel_row_psf_fit_error'].description = 'Error on measured location on image pixel rows from PSF photometry' + tab['pixel_row_psf_fit_error'].unit = u.pixel + + # Meta-data: + tab.meta['fileid'] = fileid + tab.meta['target_name'] = target_name + tab.meta['version'] = __version__ + tab.meta['template'] = None if datafile.get('template') is None else datafile['template']['fileid'] + tab.meta['diffimg'] = None if datafile.get('diffimg') is None else datafile['diffimg']['fileid'] + tab.meta['photfilter'] = photfilter + tab.meta['fwhm'] = fwhm * u.pixel + tab.meta['pixel_scale'] = pixel_scale * u.arcsec / u.pixel + tab.meta['seeing'] = (fwhm * pixel_scale) * u.arcsec + tab.meta['obstime-bmjd'] = float(image.obstime.mjd) + tab.meta['zp'] = zp_bs + tab.meta['zp_error'] = zp_error_bs + tab.meta['zp_diff'] = np.abs(zp_bs - zp) + tab.meta['zp_error_weights'] = zp_error + tab.meta['head_wcs'] = head_wcs # TODO: Are these really useful? + tab.meta['used_wcs'] = used_wcs # TODO: Are these really useful? + + # Filepath where to save photometry: + photometry_output = os.path.join(output_folder, 'photometry.ecsv') + + # Write the final table to file: + tab.write(photometry_output, format='ascii.ecsv', delimiter=',', overwrite=True) + + toc = default_timer() + + logger.info("------------------------------------------------------") + logger.info("Success!") + logger.info("Main target: %f +/- %f", tab[0]['mag'], tab[0]['mag_error']) + logger.info("Photometry took: %.1f seconds", toc - tic) + + return photometry_output diff --git a/flows/plots.py b/flows/plots.py index b051441..c5c03b3 100644 --- a/flows/plots.py +++ b/flows/plots.py @@ -26,233 +26,230 @@ matplotlib.rcParams['text.usetex'] = False matplotlib.rcParams['mathtext.fontset'] = 'dejavuserif' -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def plots_interactive(backend=('Qt5Agg', 'MacOSX', 'Qt4Agg', 'Qt5Cairo', 'TkAgg')): - """ - Change plotting to using an interactive backend. + """ + Change plotting to using an interactive backend. + + Parameters: + backend (str or list): Backend to change to. If not provided, will try different + interactive backends and use the first one that works. - Parameters: - backend (str or list): Backend to change to. If not provided, will try different - interactive backends and use the first one that works. + .. codeauthor:: Rasmus Handberg + """ - .. codeauthor:: Rasmus Handberg - """ + logger = logging.getLogger(__name__) + logger.debug("Valid interactive backends: %s", matplotlib.rcsetup.interactive_bk) - logger = logging.getLogger(__name__) - logger.debug("Valid interactive backends: %s", matplotlib.rcsetup.interactive_bk) + if isinstance(backend, str): + backend = [backend] - if isinstance(backend, str): - backend = [backend] + for bckend in backend: + if bckend not in matplotlib.rcsetup.interactive_bk: + logger.warning("Interactive backend '%s' is not found", bckend) + continue - for bckend in backend: - if bckend not in matplotlib.rcsetup.interactive_bk: - logger.warning("Interactive backend '%s' is not found", bckend) - continue + # Try to change the backend, and catch errors + # it it didn't work: + try: + plt.switch_backend(bckend) + except (ModuleNotFoundError, ImportError): + pass + else: + break - # Try to change the backend, and catch errors - # it it didn't work: - try: - plt.switch_backend(bckend) - except (ModuleNotFoundError, ImportError): - pass - else: - break -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def plots_noninteractive(): - """ - Change plotting to using a non-interactive backend, which can e.g. be used on a cluster. - Will set backend to 'Agg'. - - .. codeauthor:: Rasmus Handberg - """ - plt.switch_backend('Agg') - -#-------------------------------------------------------------------------------------------------- -def plot_image(image, ax=None, scale='log', cmap=None, origin='lower', xlabel=None, - ylabel=None, cbar=None, clabel='Flux ($e^{-}s^{-1}$)', cbar_ticks=None, - cbar_ticklabels=None, cbar_pad=None, cbar_size='4%', title=None, - percentile=95.0, vmin=None, vmax=None, offset_axes=None, color_bad='k', **kwargs): - """ - Utility function to plot a 2D image. - - Parameters: - image (2d array): Image data. - ax (matplotlib.pyplot.axes, optional): Axes in which to plot. - Default (None) is to use current active axes. - scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional): - Normalization used to stretch the colormap. - Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'`` - and ``'squared'``. - Can also be a :py:class:`astropy.visualization.ImageNormalize` object. - Default is ``'log'``. - origin (str, optional): The origin of the coordinate system. - xlabel (str, optional): Label for the x-axis. - ylabel (str, optional): Label for the y-axis. - cbar (string, optional): Location of color bar. - Choises are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``. - Default is not to create colorbar. - clabel (str, optional): Label for the color bar. - cbar_size (float, optional): Fractional size of colorbar compared to axes. Default='4%'. - cbar_pad (float, optional): Padding between axes and colorbar. - title (str or None, optional): Title for the plot. - percentile (float, optional): The fraction of pixels to keep in color-trim. - If single float given, the same fraction of pixels is eliminated from both ends. - If tuple of two floats is given, the two are used as the percentiles. - Default=95. - cmap (matplotlib colormap, optional): Colormap to use. Default is the ``Blues`` colormap. - vmin (float, optional): Lower limit to use for colormap. - vmax (float, optional): Upper limit to use for colormap. - color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black. - kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`. - - Returns: - :py:class:`matplotlib.image.AxesImage`: Image from returned - by :py:func:`matplotlib.pyplot.imshow`. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - # Backward compatible settings: - make_cbar = kwargs.pop('make_cbar', None) - if make_cbar: - raise FutureWarning("'make_cbar' is deprecated. Use 'cbar' instead.") - if not cbar: - cbar = make_cbar - - # Special treatment for boolean arrays: - if isinstance(image, np.ndarray) and image.dtype == 'bool': - if vmin is None: vmin = 0 - if vmax is None: vmax = 1 - if cbar_ticks is None: cbar_ticks = [0, 1] - if cbar_ticklabels is None: cbar_ticklabels = ['False', 'True'] - - # Calculate limits of color scaling: - interval = None - if vmin is None or vmax is None: - if allnan(image): - logger.warning("Image is all NaN") - vmin = 0 - vmax = 1 - if cbar_ticks is None: - cbar_ticks = [] - if cbar_ticklabels is None: - cbar_ticklabels = [] - elif isinstance(percentile, (list, tuple, np.ndarray)): - interval = viz.AsymmetricPercentileInterval(percentile[0], percentile[1]) - else: - interval = viz.PercentileInterval(percentile) - - # Create ImageNormalize object with extracted limits: - if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh', 'squared'): - if scale == 'log': - stretch = viz.LogStretch() - elif scale == 'linear': - stretch = viz.LinearStretch() - elif scale == 'sqrt': - stretch = viz.SqrtStretch() - elif scale == 'asinh': - stretch = viz.AsinhStretch() - elif scale == 'histeq': - stretch = viz.HistEqStretch(image[np.isfinite(image)]) - elif scale == 'sinh': - stretch = viz.SinhStretch() - elif scale == 'squared': - stretch = viz.SquaredStretch() - - # Create ImageNormalize object. Very important to use clip=False here, otherwise - # NaN points will not be plotted correctly. - norm = viz.ImageNormalize( - data=image, - interval=interval, - vmin=vmin, - vmax=vmax, - stretch=stretch, - clip=False) - - elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)): - norm = scale - else: - raise ValueError("scale {} is not available.".format(scale)) - - if offset_axes: - extent = (offset_axes[0]-0.5, offset_axes[0] + image.shape[1]-0.5, offset_axes[1]-0.5, offset_axes[1] + image.shape[0]-0.5) - else: - extent = (-0.5, image.shape[1]-0.5, -0.5, image.shape[0]-0.5) - - if ax is None: - ax = plt.gca() - - # Set up the colormap to use. If a bad color is defined, - # add it to the colormap: - if cmap is None: - cmap = copy.copy(plt.get_cmap('Blues')) - elif isinstance(cmap, str): - cmap = copy.copy(plt.get_cmap(cmap)) - - if color_bad: - cmap.set_bad(color_bad, 1.0) - - im = ax.imshow(image, cmap=cmap, norm=norm, origin=origin, extent=extent, interpolation='nearest', **kwargs) - if xlabel is not None: - ax.set_xlabel(xlabel) - if ylabel is not None: - ax.set_ylabel(ylabel) - if title is not None: - ax.set_title(title) - ax.set_xlim([extent[0], extent[1]]) - ax.set_ylim([extent[2], extent[3]]) - - if cbar: - fig = ax.figure - divider = make_axes_locatable(ax) - if cbar == 'top': - cbar_pad = 0.05 if cbar_pad is None else cbar_pad - cax = divider.append_axes('top', size=cbar_size, pad=cbar_pad) - orientation = 'horizontal' - elif cbar == 'bottom': - cbar_pad = 0.35 if cbar_pad is None else cbar_pad - cax = divider.append_axes('bottom', size=cbar_size, pad=cbar_pad) - orientation = 'horizontal' - elif cbar == 'left': - cbar_pad = 0.35 if cbar_pad is None else cbar_pad - cax = divider.append_axes('left', size=cbar_size, pad=cbar_pad) - orientation = 'vertical' - else: - cbar_pad = 0.05 if cbar_pad is None else cbar_pad - cax = divider.append_axes('right', size=cbar_size, pad=cbar_pad) - orientation = 'vertical' - - cb = fig.colorbar(im, cax=cax, orientation=orientation) - - if cbar == 'top': - cax.xaxis.set_ticks_position('top') - cax.xaxis.set_label_position('top') - elif cbar == 'left': - cax.yaxis.set_ticks_position('left') - cax.yaxis.set_label_position('left') - - if clabel is not None: - cb.set_label(clabel) - if cbar_ticks is not None: - cb.set_ticks(cbar_ticks) - if cbar_ticklabels is not None: - cb.set_ticklabels(cbar_ticklabels) - - #cax.yaxis.set_major_locator(matplotlib.ticker.AutoLocator()) - #cax.yaxis.set_minor_locator(matplotlib.ticker.AutoLocator()) - cax.tick_params(which='both', direction='out', pad=5) - - # Settings for ticks: - integer_locator = MaxNLocator(nbins=10, integer=True) - ax.xaxis.set_major_locator(integer_locator) - ax.xaxis.set_minor_locator(integer_locator) - ax.yaxis.set_major_locator(integer_locator) - ax.yaxis.set_minor_locator(integer_locator) - ax.tick_params(which='both', direction='out', pad=5) - ax.xaxis.tick_bottom() - ax.yaxis.tick_left() - - return im + """ + Change plotting to using a non-interactive backend, which can e.g. be used on a cluster. + Will set backend to 'Agg'. + + .. codeauthor:: Rasmus Handberg + """ + plt.switch_backend('Agg') + + +# -------------------------------------------------------------------------------------------------- +def plot_image(image, ax=None, scale='log', cmap=None, origin='lower', xlabel=None, ylabel=None, cbar=None, + clabel='Flux ($e^{-}s^{-1}$)', cbar_ticks=None, cbar_ticklabels=None, cbar_pad=None, cbar_size='4%', + title=None, percentile=95.0, vmin=None, vmax=None, offset_axes=None, color_bad='k', **kwargs): + """ + Utility function to plot a 2D image. + + Parameters: + image (2d array): Image data. + ax (matplotlib.pyplot.axes, optional): Axes in which to plot. + Default (None) is to use current active axes. + scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional): + Normalization used to stretch the colormap. + Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'`` + and ``'squared'``. + Can also be a :py:class:`astropy.visualization.ImageNormalize` object. + Default is ``'log'``. + origin (str, optional): The origin of the coordinate system. + xlabel (str, optional): Label for the x-axis. + ylabel (str, optional): Label for the y-axis. + cbar (string, optional): Location of color bar. + Choises are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``. + Default is not to create colorbar. + clabel (str, optional): Label for the color bar. + cbar_size (float, optional): Fractional size of colorbar compared to axes. Default='4%'. + cbar_pad (float, optional): Padding between axes and colorbar. + title (str or None, optional): Title for the plot. + percentile (float, optional): The fraction of pixels to keep in color-trim. + If single float given, the same fraction of pixels is eliminated from both ends. + If tuple of two floats is given, the two are used as the percentiles. + Default=95. + cmap (matplotlib colormap, optional): Colormap to use. Default is the ``Blues`` colormap. + vmin (float, optional): Lower limit to use for colormap. + vmax (float, optional): Upper limit to use for colormap. + color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black. + kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`. + + Returns: + :py:class:`matplotlib.image.AxesImage`: Image from returned + by :py:func:`matplotlib.pyplot.imshow`. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + + # Backward compatible settings: + make_cbar = kwargs.pop('make_cbar', None) + if make_cbar: + raise FutureWarning("'make_cbar' is deprecated. Use 'cbar' instead.") + if not cbar: + cbar = make_cbar + + # Special treatment for boolean arrays: + if isinstance(image, np.ndarray) and image.dtype == 'bool': + if vmin is None: vmin = 0 + if vmax is None: vmax = 1 + if cbar_ticks is None: cbar_ticks = [0, 1] + if cbar_ticklabels is None: cbar_ticklabels = ['False', 'True'] + + # Calculate limits of color scaling: + interval = None + if vmin is None or vmax is None: + if allnan(image): + logger.warning("Image is all NaN") + vmin = 0 + vmax = 1 + if cbar_ticks is None: + cbar_ticks = [] + if cbar_ticklabels is None: + cbar_ticklabels = [] + elif isinstance(percentile, (list, tuple, np.ndarray)): + interval = viz.AsymmetricPercentileInterval(percentile[0], percentile[1]) + else: + interval = viz.PercentileInterval(percentile) + + # Create ImageNormalize object with extracted limits: + if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh', 'squared'): + if scale == 'log': + stretch = viz.LogStretch() + elif scale == 'linear': + stretch = viz.LinearStretch() + elif scale == 'sqrt': + stretch = viz.SqrtStretch() + elif scale == 'asinh': + stretch = viz.AsinhStretch() + elif scale == 'histeq': + stretch = viz.HistEqStretch(image[np.isfinite(image)]) + elif scale == 'sinh': + stretch = viz.SinhStretch() + elif scale == 'squared': + stretch = viz.SquaredStretch() + + # Create ImageNormalize object. Very important to use clip=False here, otherwise + # NaN points will not be plotted correctly. + norm = viz.ImageNormalize(data=image, interval=interval, vmin=vmin, vmax=vmax, stretch=stretch, clip=False) + + elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)): + norm = scale + else: + raise ValueError("scale {} is not available.".format(scale)) + + if offset_axes: + extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5, offset_axes[1] - 0.5, + offset_axes[1] + image.shape[0] - 0.5) + else: + extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5) + + if ax is None: + ax = plt.gca() + + # Set up the colormap to use. If a bad color is defined, + # add it to the colormap: + if cmap is None: + cmap = copy.copy(plt.get_cmap('Blues')) + elif isinstance(cmap, str): + cmap = copy.copy(plt.get_cmap(cmap)) + + if color_bad: + cmap.set_bad(color_bad, 1.0) + + im = ax.imshow(image, cmap=cmap, norm=norm, origin=origin, extent=extent, interpolation='nearest', **kwargs) + if xlabel is not None: + ax.set_xlabel(xlabel) + if ylabel is not None: + ax.set_ylabel(ylabel) + if title is not None: + ax.set_title(title) + ax.set_xlim([extent[0], extent[1]]) + ax.set_ylim([extent[2], extent[3]]) + + if cbar: + fig = ax.figure + divider = make_axes_locatable(ax) + if cbar == 'top': + cbar_pad = 0.05 if cbar_pad is None else cbar_pad + cax = divider.append_axes('top', size=cbar_size, pad=cbar_pad) + orientation = 'horizontal' + elif cbar == 'bottom': + cbar_pad = 0.35 if cbar_pad is None else cbar_pad + cax = divider.append_axes('bottom', size=cbar_size, pad=cbar_pad) + orientation = 'horizontal' + elif cbar == 'left': + cbar_pad = 0.35 if cbar_pad is None else cbar_pad + cax = divider.append_axes('left', size=cbar_size, pad=cbar_pad) + orientation = 'vertical' + else: + cbar_pad = 0.05 if cbar_pad is None else cbar_pad + cax = divider.append_axes('right', size=cbar_size, pad=cbar_pad) + orientation = 'vertical' + + cb = fig.colorbar(im, cax=cax, orientation=orientation) + + if cbar == 'top': + cax.xaxis.set_ticks_position('top') + cax.xaxis.set_label_position('top') + elif cbar == 'left': + cax.yaxis.set_ticks_position('left') + cax.yaxis.set_label_position('left') + + if clabel is not None: + cb.set_label(clabel) + if cbar_ticks is not None: + cb.set_ticks(cbar_ticks) + if cbar_ticklabels is not None: + cb.set_ticklabels(cbar_ticklabels) + + # cax.yaxis.set_major_locator(matplotlib.ticker.AutoLocator()) + # cax.yaxis.set_minor_locator(matplotlib.ticker.AutoLocator()) + cax.tick_params(which='both', direction='out', pad=5) + + # Settings for ticks: + integer_locator = MaxNLocator(nbins=10, integer=True) + ax.xaxis.set_major_locator(integer_locator) + ax.xaxis.set_minor_locator(integer_locator) + ax.yaxis.set_major_locator(integer_locator) + ax.yaxis.set_minor_locator(integer_locator) + ax.tick_params(which='both', direction='out', pad=5) + ax.xaxis.tick_bottom() + ax.yaxis.tick_left() + + return im diff --git a/flows/reference_cleaning.py b/flows/reference_cleaning.py index 5d2c689..27daac5 100644 --- a/flows/reference_cleaning.py +++ b/flows/reference_cleaning.py @@ -18,313 +18,300 @@ from scipy.spatial import KDTree import pandas as pd # TODO: Convert to pure numpy implementation -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class MinStarError(RuntimeError): - pass - -#-------------------------------------------------------------------------------------------------- -def force_reject_g2d(xarray, yarray, image, get_fwhm=True, rsq_min=0.5, radius=10, fwhm_guess=6.0, - fwhm_min=3.5, fwhm_max=18.0): - """ - - Parameters: - xarray: - yarray: - image: - get_fwhm (bool, optional): - rsq_min (float, optional): - radius (float, optional): - fwhm_guess=6.0: - fwhm_min=3.5: - fwhm_max=18.0: - - Returns: - tuple: - - masked_xys: - - mask: - - masked_rsqs: - - .. codeauthor:: Emir Karamehmetoglu - .. codeauthor:: Rasmus Handberg - """ - # Set up 2D Gaussian model for fitting to reference stars: - g2d = models.Gaussian2D(amplitude=1.0, - x_mean=radius, - y_mean=radius, - x_stddev=fwhm_guess * gaussian_fwhm_to_sigma) - g2d.amplitude.bounds = (0.1, 2.0) - g2d.x_mean.bounds = (0.5 * radius, 1.5 * radius) - g2d.y_mean.bounds = (0.5 * radius, 1.5 * radius) - g2d.x_stddev.bounds = ( - fwhm_min * gaussian_fwhm_to_sigma, - fwhm_max * gaussian_fwhm_to_sigma - ) - g2d.y_stddev.tied = lambda model: model.x_stddev - g2d.theta.fixed = True - - gfitter = fitting.LevMarLSQFitter() - - # Stars reject - N = len(xarray) - fwhms = np.full((N, 2), np.NaN) - xys = np.full((N, 2), np.NaN) - rsqs = np.full(N, np.NaN) - for i, (x, y) in enumerate(zip(xarray, yarray)): - x = int(np.round(x)) - y = int(np.round(y)) - xmin = max(x - radius, 0) - xmax = min(x + radius + 1, image.shape[1]) - ymin = max(y - radius, 0) - ymax = min(y + radius + 1, image.shape[0]) - - curr_star = deepcopy(image.subclean[ymin:ymax, xmin:xmax]) - - edge = np.zeros_like(curr_star, dtype='bool') - edge[(0, -1), :] = True - edge[:, (0, -1)] = True - curr_star -= nanmedian(curr_star[edge]) - curr_star /= np.nanmax(curr_star) - - ypos, xpos = np.mgrid[:curr_star.shape[0], :curr_star.shape[1]] - gfit = gfitter(g2d, x=xpos, y=ypos, z=curr_star) - - # Center - xys[i] = np.array([gfit.x_mean + x - radius, gfit.y_mean + y - radius], dtype='float64') - - # Calculate rsq - sstot = nansum((curr_star - nanmean(curr_star))**2) - sserr = nansum(gfitter.fit_info['fvec']**2) - rsqs[i] = 0 if sstot == 0 else 1.0 - (sserr / sstot) - - # FWHM - fwhms[i] = gfit.x_fwhm - - masked_xys = np.ma.masked_array(xys, ~np.isfinite(xys)) - masked_rsqs = np.ma.masked_array(rsqs, ~np.isfinite(rsqs)) - mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0) # Reject Rsq < rsq_min - # changed - #masked_xys = masked_xys[mask] # Clean extracted array. - # to - masked_xys.mask[~mask] = True - # don't know if it breaks anything, but it doesn't make sence if - # len(masked_xys) != len(masked_rsqs) FIXME - masked_fwhms = np.ma.masked_array(fwhms, ~np.isfinite(fwhms)) - - if get_fwhm: - return masked_fwhms, masked_xys, mask, masked_rsqs - return masked_xys, mask, masked_rsqs - -#-------------------------------------------------------------------------------------------------- -def clean_with_rsq_and_get_fwhm(masked_fwhms, masked_rsqs, references, min_fwhm_references=2, - min_references=6, rsq_min=0.15): - """ - Clean references and obtain fwhm using RSQ values. - - Parameters: - masked_fwhms (np.ma.maskedarray): array of fwhms - masked_rsqs (np.ma.maskedarray): array of rsq values - references (astropy.table.Table): table or reference stars - min_fwhm_references: (Default 2) min stars to get a fwhm - min_references: (Default 6) min stars to aim for when cutting by R2 - rsq_min: (Default 0.15) min rsq value - - .. codeauthor:: Emir Karamehmetoglu - """ - min_references_now = min_references - rsqvals = np.arange(rsq_min, 0.95, 0.15)[::-1] - fwhm_found = False - min_references_achieved = False - - # Clean based on R^2 Value - while not min_references_achieved: - for rsqval in rsqvals: - mask = (masked_rsqs >= rsqval) & (masked_rsqs < 1.0) - nreferences = np.sum(np.isfinite(masked_fwhms[mask])) - if nreferences >= min_fwhm_references: - _fwhms_cut_ = np.nanmean(sigma_clip(masked_fwhms[mask], maxiters=100, sigma=2.0)) - if not fwhm_found: - fwhm = _fwhms_cut_ - fwhm_found = True - if nreferences >= min_references_now: - references = references[mask] - min_references_achieved = True - break - if min_references_achieved: - break - min_references_now = min_references_now - 2 - if (min_references_now < 2) and fwhm_found: - break - elif not fwhm_found: - raise RuntimeError("Could not estimate FWHM") - - if np.isnan(fwhm): - raise RuntimeError("Could not estimate FWHM") - - # if minimum references not found, then take what we can get with even a weaker cut. - # TODO: Is this right, or should we grab rsq_min (or even weaker?) - min_references_now = min_references - 2 - while not min_references_achieved: - mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0) - nreferences = np.sum(np.isfinite(masked_fwhms[mask])) - if nreferences >= min_references_now: - references = references[mask] - min_references_achieved = True - rsq_min = rsq_min - 0.07 - min_references_now = min_references_now - 1 - - # Check len of references as this is a destructive cleaning. - # if len(references) == 2: logger.info('2 reference stars remaining, check WCS and image quality') - if len(references) < 2: - raise RuntimeError(f"{len(references)} References remaining; could not clean.") - return fwhm, references - -#-------------------------------------------------------------------------------------------------- + pass + + +# -------------------------------------------------------------------------------------------------- +def force_reject_g2d(xarray, yarray, image, get_fwhm=True, rsq_min=0.5, radius=10, fwhm_guess=6.0, fwhm_min=3.5, + fwhm_max=18.0): + """ + + Parameters: + xarray: + yarray: + image: + get_fwhm (bool, optional): + rsq_min (float, optional): + radius (float, optional): + fwhm_guess=6.0: + fwhm_min=3.5: + fwhm_max=18.0: + + Returns: + tuple: + - masked_xys: + - mask: + - masked_rsqs: + + .. codeauthor:: Emir Karamehmetoglu + .. codeauthor:: Rasmus Handberg + """ + # Set up 2D Gaussian model for fitting to reference stars: + g2d = models.Gaussian2D(amplitude=1.0, x_mean=radius, y_mean=radius, x_stddev=fwhm_guess * gaussian_fwhm_to_sigma) + g2d.amplitude.bounds = (0.1, 2.0) + g2d.x_mean.bounds = (0.5 * radius, 1.5 * radius) + g2d.y_mean.bounds = (0.5 * radius, 1.5 * radius) + g2d.x_stddev.bounds = (fwhm_min * gaussian_fwhm_to_sigma, fwhm_max * gaussian_fwhm_to_sigma) + g2d.y_stddev.tied = lambda model: model.x_stddev + g2d.theta.fixed = True + + gfitter = fitting.LevMarLSQFitter() + + # Stars reject + N = len(xarray) + fwhms = np.full((N, 2), np.NaN) + xys = np.full((N, 2), np.NaN) + rsqs = np.full(N, np.NaN) + for i, (x, y) in enumerate(zip(xarray, yarray)): + x = int(np.round(x)) + y = int(np.round(y)) + xmin = max(x - radius, 0) + xmax = min(x + radius + 1, image.shape[1]) + ymin = max(y - radius, 0) + ymax = min(y + radius + 1, image.shape[0]) + + curr_star = deepcopy(image.subclean[ymin:ymax, xmin:xmax]) + + edge = np.zeros_like(curr_star, dtype='bool') + edge[(0, -1), :] = True + edge[:, (0, -1)] = True + curr_star -= nanmedian(curr_star[edge]) + curr_star /= np.nanmax(curr_star) + + ypos, xpos = np.mgrid[:curr_star.shape[0], :curr_star.shape[1]] + gfit = gfitter(g2d, x=xpos, y=ypos, z=curr_star) + + # Center + xys[i] = np.array([gfit.x_mean + x - radius, gfit.y_mean + y - radius], dtype='float64') + + # Calculate rsq + sstot = nansum((curr_star - nanmean(curr_star)) ** 2) + sserr = nansum(gfitter.fit_info['fvec'] ** 2) + rsqs[i] = 0 if sstot == 0 else 1.0 - (sserr / sstot) + + # FWHM + fwhms[i] = gfit.x_fwhm + + masked_xys = np.ma.masked_array(xys, ~np.isfinite(xys)) + masked_rsqs = np.ma.masked_array(rsqs, ~np.isfinite(rsqs)) + mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0) # Reject Rsq < rsq_min + # changed + # masked_xys = masked_xys[mask] # Clean extracted array. + # to + masked_xys.mask[~mask] = True + # don't know if it breaks anything, but it doesn't make sence if + # len(masked_xys) != len(masked_rsqs) FIXME + masked_fwhms = np.ma.masked_array(fwhms, ~np.isfinite(fwhms)) + + if get_fwhm: + return masked_fwhms, masked_xys, mask, masked_rsqs + return masked_xys, mask, masked_rsqs + + +# -------------------------------------------------------------------------------------------------- +def clean_with_rsq_and_get_fwhm(masked_fwhms, masked_rsqs, references, min_fwhm_references=2, min_references=6, + rsq_min=0.15): + """ + Clean references and obtain fwhm using RSQ values. + + Parameters: + masked_fwhms (np.ma.maskedarray): array of fwhms + masked_rsqs (np.ma.maskedarray): array of rsq values + references (astropy.table.Table): table or reference stars + min_fwhm_references: (Default 2) min stars to get a fwhm + min_references: (Default 6) min stars to aim for when cutting by R2 + rsq_min: (Default 0.15) min rsq value + + .. codeauthor:: Emir Karamehmetoglu + """ + min_references_now = min_references + rsqvals = np.arange(rsq_min, 0.95, 0.15)[::-1] + fwhm_found = False + min_references_achieved = False + + # Clean based on R^2 Value + while not min_references_achieved: + for rsqval in rsqvals: + mask = (masked_rsqs >= rsqval) & (masked_rsqs < 1.0) + nreferences = np.sum(np.isfinite(masked_fwhms[mask])) + if nreferences >= min_fwhm_references: + _fwhms_cut_ = np.nanmean(sigma_clip(masked_fwhms[mask], maxiters=100, sigma=2.0)) + if not fwhm_found: + fwhm = _fwhms_cut_ + fwhm_found = True + if nreferences >= min_references_now: + references = references[mask] + min_references_achieved = True + break + if min_references_achieved: + break + min_references_now = min_references_now - 2 + if (min_references_now < 2) and fwhm_found: + break + elif not fwhm_found: + raise RuntimeError("Could not estimate FWHM") + + if np.isnan(fwhm): + raise RuntimeError("Could not estimate FWHM") + + # if minimum references not found, then take what we can get with even a weaker cut. + # TODO: Is this right, or should we grab rsq_min (or even weaker?) + min_references_now = min_references - 2 + while not min_references_achieved: + mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0) + nreferences = np.sum(np.isfinite(masked_fwhms[mask])) + if nreferences >= min_references_now: + references = references[mask] + min_references_achieved = True + rsq_min = rsq_min - 0.07 + min_references_now = min_references_now - 1 + + # Check len of references as this is a destructive cleaning. + # if len(references) == 2: logger.info('2 reference stars remaining, check WCS and image quality') + if len(references) < 2: + raise RuntimeError(f"{len(references)} References remaining; could not clean.") + return fwhm, references + + +# -------------------------------------------------------------------------------------------------- def mkposxy(posx, posy): - '''Make 2D np array for astroalign''' - img_posxy = np.array([[x, y] for x, y in zip(posx, posy)], dtype="float64") - return img_posxy + '''Make 2D np array for astroalign''' + img_posxy = np.array([[x, y] for x, y in zip(posx, posy)], dtype="float64") + return img_posxy + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def try_transform(source, target, pixeltol=2, nnearest=5, max_stars=50): - aa.NUM_NEAREST_NEIGHBORS = nnearest - aa.PIXEL_TOL = pixeltol - transform, (sourcestars, targetstars) = aa.find_transform( - source, - target, - max_control_points=max_stars) - return sourcestars, targetstars - -#-------------------------------------------------------------------------------------------------- + aa.NUM_NEAREST_NEIGHBORS = nnearest + aa.PIXEL_TOL = pixeltol + transform, (sourcestars, targetstars) = aa.find_transform(source, target, max_control_points=max_stars) + return sourcestars, targetstars + + +# -------------------------------------------------------------------------------------------------- def try_astroalign(source, target, pixeltol=2, nnearest=5, max_stars_n=50): - # Get indexes of matched stars - success = False - try: - source_stars, target_stars = try_transform( - source, - target, - pixeltol=pixeltol, - nnearest=nnearest, - max_stars=max_stars_n) - source_ind = np.argwhere(np.in1d(source, source_stars)[::2]).flatten() - target_ind = np.argwhere(np.in1d(target, target_stars)[::2]).flatten() - success = True - except aa.MaxIterError: - source_ind, target_ind = 'None', 'None' - return source_ind, target_ind, success - -#-------------------------------------------------------------------------------------------------- -def min_to_max_astroalign(source, target, fwhm=5, fwhm_min=1, fwhm_max=4, knn_min=5, - knn_max=20, max_stars=100, min_matches=3): - """Try to find matches using astroalign asterisms by stepping through some parameters.""" - # Set max_control_points par based on number of stars and max_stars. - nstars = max(len(source), len(source)) - if max_stars >= nstars: - max_stars_list = 'None' - else: - if max_stars > 60: - max_stars_list = (max_stars, 50, 4, 3) - else: - max_stars_list = (max_stars, 6, 4, 3) - - # Create max_stars step-through list if not given - if max_stars_list == 'None': - if nstars > 6: - max_stars_list = (nstars, 5, 3) - elif nstars > 3: - max_stars_list = (nstars, 3) - - pixeltols = np.linspace(int(fwhm * fwhm_min), int(fwhm * fwhm_max), 4, dtype=int) - nearest_neighbors = np.linspace(knn_min, min(knn_max, nstars), 4, dtype=int) - - for max_stars_n in max_stars_list: - for pixeltol in pixeltols: - for nnearest in nearest_neighbors: - source_ind, target_ind, success = try_astroalign(source, target, - pixeltol=pixeltol, - nnearest=nnearest, - max_stars_n=max_stars_n) - if success: - if len(source_ind) >= min_matches: - return source_ind, target_ind, success - else: - success = False - return 'None', 'None', success - -#-------------------------------------------------------------------------------------------------- + # Get indexes of matched stars + success = False + try: + source_stars, target_stars = try_transform(source, target, pixeltol=pixeltol, nnearest=nnearest, + max_stars=max_stars_n) + source_ind = np.argwhere(np.in1d(source, source_stars)[::2]).flatten() + target_ind = np.argwhere(np.in1d(target, target_stars)[::2]).flatten() + success = True + except aa.MaxIterError: + source_ind, target_ind = 'None', 'None' + return source_ind, target_ind, success + + +# -------------------------------------------------------------------------------------------------- +def min_to_max_astroalign(source, target, fwhm=5, fwhm_min=1, fwhm_max=4, knn_min=5, knn_max=20, max_stars=100, + min_matches=3): + """Try to find matches using astroalign asterisms by stepping through some parameters.""" + # Set max_control_points par based on number of stars and max_stars. + nstars = max(len(source), len(source)) + if max_stars >= nstars: + max_stars_list = 'None' + else: + if max_stars > 60: + max_stars_list = (max_stars, 50, 4, 3) + else: + max_stars_list = (max_stars, 6, 4, 3) + + # Create max_stars step-through list if not given + if max_stars_list == 'None': + if nstars > 6: + max_stars_list = (nstars, 5, 3) + elif nstars > 3: + max_stars_list = (nstars, 3) + + pixeltols = np.linspace(int(fwhm * fwhm_min), int(fwhm * fwhm_max), 4, dtype=int) + nearest_neighbors = np.linspace(knn_min, min(knn_max, nstars), 4, dtype=int) + + for max_stars_n in max_stars_list: + for pixeltol in pixeltols: + for nnearest in nearest_neighbors: + source_ind, target_ind, success = try_astroalign(source, target, pixeltol=pixeltol, nnearest=nnearest, + max_stars_n=max_stars_n) + if success: + if len(source_ind) >= min_matches: + return source_ind, target_ind, success + else: + success = False + return 'None', 'None', success + + +# -------------------------------------------------------------------------------------------------- def kdtree(source, target, fwhm=5, fwhm_max=4, min_matches=3): - '''Use KDTree to get nearest neighbor matches within fwhm_max*fwhm distance''' - - # Use KDTree to rapidly efficiently query nearest neighbors - - tt = KDTree(target) - st = KDTree(source) - matches_list = st.query_ball_tree(tt, r=fwhm * fwhm_max) - - #indx = [] - targets = [] - sources = [] - for j, (sstar, match) in enumerate(zip(source, matches_list)): - if np.array(target[match]).size != 0: - targets.append(match[0]) - sources.append(j) - sources = np.array(sources, dtype=int) - targets = np.array(targets, dtype=int) - - # Return indexes of matches - return sources, targets, len(sources) >= min_matches - -#-------------------------------------------------------------------------------------------------- -def get_new_wcs(extracted_ind, extracted_stars, clean_references, ref_ind, obstime, rakey='ra_obs', - deckey='decl_obs'): - - targets = (extracted_stars[extracted_ind][:, 0], extracted_stars[extracted_ind][:, 1]) - - c = SkyCoord( - ra=clean_references[rakey][ref_ind], - dec=clean_references[deckey][ref_ind], - frame='icrs', - obstime=obstime - ) - return wcs.utils.fit_wcs_from_points(targets, c) - -#-------------------------------------------------------------------------------------------------- -def get_clean_references(references, masked_rsqs, min_references_ideal=6, min_references_abs=3, - rsq_min=0.15, rsq_ideal=0.5, keep_max=100, rescue_bad: bool = True): - - # Greedy first try - mask = (masked_rsqs >= rsq_ideal) & (masked_rsqs < 1.0) - if np.sum(np.isfinite(masked_rsqs[mask])) >= min_references_ideal: - if len(references[mask]) <= keep_max: - return references[mask] - elif len(references[mask]) >= keep_max: - - df = pd.DataFrame(masked_rsqs, columns=['rsq']) - masked_rsqs.mask = ~mask - nmasked_rsqs = df.sort_values('rsq', ascending=False).dropna().index._data - return references[nmasked_rsqs[:keep_max]] - - # Desperate second try - mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0) - masked_rsqs.mask = ~mask - - # Switching to pandas for easier selection - df = pd.DataFrame(masked_rsqs, columns=['rsq']) - nmasked_rsqs = deepcopy( - df.sort_values('rsq', ascending=False).dropna().index._data) - nmasked_rsqs = nmasked_rsqs[:min(min_references_ideal, len(nmasked_rsqs))] - if len(nmasked_rsqs) >= min_references_abs: - return references[nmasked_rsqs] - if not rescue_bad: - raise MinStarError(f'Less than {min_references_abs} clean stars and rescue_bad = False') - - # Extremely desperate last ditch attempt i.e. "rescue bad" - mask = (masked_rsqs >= 0.02) & (masked_rsqs < 1.0) - masked_rsqs.mask = ~mask - - # Switch to pandas - df = pd.DataFrame(masked_rsqs, columns=['rsq']) - nmasked_rsqs = df.sort_values('rsq', ascending=False).dropna().index._data - nmasked_rsqs = nmasked_rsqs[:min(min_references_ideal, len(nmasked_rsqs))] - if len(nmasked_rsqs) < 2: - raise MinStarError('Less than 2 clean stars.') - return references[nmasked_rsqs] # Return if len >= 2 + '''Use KDTree to get nearest neighbor matches within fwhm_max*fwhm distance''' + + # Use KDTree to rapidly efficiently query nearest neighbors + + tt = KDTree(target) + st = KDTree(source) + matches_list = st.query_ball_tree(tt, r=fwhm * fwhm_max) + + # indx = [] + targets = [] + sources = [] + for j, (sstar, match) in enumerate(zip(source, matches_list)): + if np.array(target[match]).size != 0: + targets.append(match[0]) + sources.append(j) + sources = np.array(sources, dtype=int) + targets = np.array(targets, dtype=int) + + # Return indexes of matches + return sources, targets, len(sources) >= min_matches + + +# -------------------------------------------------------------------------------------------------- +def get_new_wcs(extracted_ind, extracted_stars, clean_references, ref_ind, obstime, rakey='ra_obs', deckey='decl_obs'): + targets = (extracted_stars[extracted_ind][:, 0], extracted_stars[extracted_ind][:, 1]) + + c = SkyCoord(ra=clean_references[rakey][ref_ind], dec=clean_references[deckey][ref_ind], frame='icrs', + obstime=obstime) + return wcs.utils.fit_wcs_from_points(targets, c) + + +# -------------------------------------------------------------------------------------------------- +def get_clean_references(references, masked_rsqs, min_references_ideal=6, min_references_abs=3, rsq_min=0.15, + rsq_ideal=0.5, keep_max=100, rescue_bad: bool = True): + # Greedy first try + mask = (masked_rsqs >= rsq_ideal) & (masked_rsqs < 1.0) + if np.sum(np.isfinite(masked_rsqs[mask])) >= min_references_ideal: + if len(references[mask]) <= keep_max: + return references[mask] + elif len(references[mask]) >= keep_max: + + df = pd.DataFrame(masked_rsqs, columns=['rsq']) + masked_rsqs.mask = ~mask + nmasked_rsqs = df.sort_values('rsq', ascending=False).dropna().index._data + return references[nmasked_rsqs[:keep_max]] + + # Desperate second try + mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0) + masked_rsqs.mask = ~mask + + # Switching to pandas for easier selection + df = pd.DataFrame(masked_rsqs, columns=['rsq']) + nmasked_rsqs = deepcopy(df.sort_values('rsq', ascending=False).dropna().index._data) + nmasked_rsqs = nmasked_rsqs[:min(min_references_ideal, len(nmasked_rsqs))] + if len(nmasked_rsqs) >= min_references_abs: + return references[nmasked_rsqs] + if not rescue_bad: + raise MinStarError(f'Less than {min_references_abs} clean stars and rescue_bad = False') + + # Extremely desperate last ditch attempt i.e. "rescue bad" + mask = (masked_rsqs >= 0.02) & (masked_rsqs < 1.0) + masked_rsqs.mask = ~mask + + # Switch to pandas + df = pd.DataFrame(masked_rsqs, columns=['rsq']) + nmasked_rsqs = df.sort_values('rsq', ascending=False).dropna().index._data + nmasked_rsqs = nmasked_rsqs[:min(min_references_ideal, len(nmasked_rsqs))] + if len(nmasked_rsqs) < 2: + raise MinStarError('Less than 2 clean stars.') + return references[nmasked_rsqs] # Return if len >= 2 diff --git a/flows/run_imagematch.py b/flows/run_imagematch.py index 5462905..418b696 100644 --- a/flows/run_imagematch.py +++ b/flows/run_imagematch.py @@ -17,18 +17,18 @@ import re from astropy.io import fits from astropy.wcs.utils import proj_plane_pixel_area -#from setuptools import Distribution -#from setuptools.command.install import install +from tendrils import api from .load_image import load_image -from . import api -#-------------------------------------------------------------------------------------------------- -#class OnlyGetScriptPath(install): + + +# -------------------------------------------------------------------------------------------------- +# class OnlyGetScriptPath(install): # def run(self): # # does not call install.run() by design # self.distribution.install_scripts = self.install_scripts -#def get_setuptools_script_dir(): +# def get_setuptools_script_dir(): # dist = Distribution({'cmdclass': {'install': OnlyGetScriptPath}}) # dist.dry_run = True # not sure if necessary, but to be safe # dist.parse_config_files() @@ -37,152 +37,140 @@ # command.run() # return dist.install_scripts -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def run_imagematch(datafile, target=None, star_coord=None, fwhm=None, pixel_scale=None): - """ - Run ImageMatch on a datafile. - - Parameters: - datafile (dict): Data file to run ImageMatch on. - target (:class:`astropy.table.Table`, optional): Target informaton. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - if datafile.get('template') is None: - raise ValueError("DATAFILE input does not specify a template to use.") - - # Extract paths to science and reference images: - reference_image = os.path.join(datafile['archive_path'], datafile['template']['path']) - science_image = os.path.join(datafile['archive_path'], datafile['path']) - - # If the target was not provided in the function call, - # use the API to get the target information: - if target is None: - catalog = api.get_catalog(datafile['targetid'], output='table') - target = catalog['target'][0] - - # Find the path to where the ImageMatch program is installed. - # This is to avoid problems with it not being on the users PATH - # and if the user is using some other version of the python executable. - # TODO: There must be a better way of doing this! - #imgmatch = os.path.join(get_setuptools_script_dir(), 'ImageMatch') - if os.name == "nt": - out = subprocess.check_output(["where", "ImageMatch"], universal_newlines=True) - imgmatch = out.strip() - else: - out = subprocess.check_output(["whereis", "ImageMatch"], universal_newlines=True) - out = re.match('ImageMatch: (.+)', out.strip()) - imgmatch = out.group(1) - - if not os.path.isfile(imgmatch): - raise FileNotFoundError("ImageMatch not found") - - # Find the ImageMatch config file to use based on the site of the observations: - __dir__ = os.path.dirname(os.path.abspath(__file__)) - if datafile['site'] in (1,3,4,6): - config_file = os.path.join(__dir__, 'imagematch', 'imagematch_lcogt.cfg') - elif datafile['site'] == 2: - config_file = os.path.join(__dir__, 'imagematch', 'imagematch_hawki.cfg') - elif datafile['site'] == 5: - config_file = os.path.join(__dir__, 'imagematch', 'imagematch_alfosc.cfg') - else: - config_file = os.path.join(__dir__, 'imagematch', 'imagematch_default.cfg') - if not os.path.isfile(config_file): - raise FileNotFoundError(config_file) - - if pixel_scale is None: - if datafile['site'] in (1,3,4,6): - # LCOGT provides the pixel scale directly in the header - pixel_scale = 'PIXSCALE' - else: - image = load_image(science_image) - pixel_area = proj_plane_pixel_area(image.wcs) - pixel_scale = np.sqrt(pixel_area)*3600 # arcsec/pixel - logger.info("Calculated science image pixel scale: %f", pixel_scale) - - if datafile['template']['site'] in (1,3,4,6): - # LCOGT provides the pixel scale directly in the header - mscale = 'PIXSCALE' - else: - template = load_image(reference_image) - template_pixel_area = proj_plane_pixel_area(template.wcs.celestial) - mscale = np.sqrt(template_pixel_area)*3600 # arcsec/pixel - logger.info("Calculated template pixel scale: %f", mscale) - - # Scale kernel radius with FWHM: - if fwhm is None: - kernel_radius = 9 - else: - kernel_radius = max(9, int(np.ceil(1.5*fwhm))) - if kernel_radius % 2 == 0: - kernel_radius += 1 - - # We will work in a temporary directory, since ImageMatch produces - # a lot of extra output files that we don't want to have lying around - # after it completes - with tempfile.TemporaryDirectory() as tmpdir: - - # Copy the science and reference image to the temp dir: - shutil.copy(reference_image, tmpdir) - shutil.copy(science_image, tmpdir) - - # Construct the command to run ImageMatch: - for match_threshold in (3.0, 5.0, 7.0, 10.0): - cmd = '"{python:s}" "{imgmatch:s}" -cfg "{config_file:s}" -snx {target_ra:.10f}d -sny {target_dec:.10f}d -p {kernel_radius:d} -o {order:d} -s {match:f} -scale {pixel_scale:} -mscale {mscale:} -m "{reference_image:s}" "{science_image:s}"'.format( - python=sys.executable, - imgmatch=imgmatch, - config_file=config_file, - reference_image=os.path.basename(reference_image), - science_image=os.path.basename(science_image), - target_ra=target['ra'], - target_dec=target['decl'], - match=match_threshold, - kernel_radius=kernel_radius, - pixel_scale=pixel_scale, - mscale=mscale, - order=1 - ) - logger.info("Executing command: %s", cmd) - - # Run the command in a subprocess: - cmd = shlex.split(cmd) - proc = subprocess.Popen(cmd, - cwd=tmpdir, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True) - stdout_data, stderr_data = proc.communicate() - returncode = proc.returncode - proc.kill() # Cleanup - Is this really needed? - - # Check the outputs from the subprocess: - logger.info("Return code: %d", returncode) - logger.info("STDOUT:\n%s", stdout_data.strip()) - if stderr_data.strip() != '': - logger.error("STDERR:\n%s", stderr_data.strip()) - if returncode < 0: - raise Exception("ImageMatch failed. Processed killed by OS with returncode %d." % returncode) - elif 'Failed object match... giving up.' in stdout_data: - #raise Exception("ImageMatch giving up matching objects") - continue - elif returncode > 0: - raise Exception("ImageMatch failed.") - - # Load the resulting difference image into memory: - diffimg_name = re.sub(r'\.fits(\.gz|\.bz2)?$', r'diff.fits\1', os.path.basename(science_image)) - diffimg_path = os.path.join(tmpdir, diffimg_name) - if not os.path.isfile(diffimg_path): - raise FileNotFoundError(diffimg_path) - - break - - else: - raise Exception("ImageMatch could not create difference image.") - - with fits.open(diffimg_path, mode='readonly') as hdu: - diffimg = np.asarray(hdu[0].data) - - return diffimg + """ + Run ImageMatch on a datafile. + + Parameters: + datafile (dict): Data file to run ImageMatch on. + target (:class:`astropy.table.Table`, optional): Target informaton. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + + if datafile.get('template') is None: + raise ValueError("DATAFILE input does not specify a template to use.") + + # Extract paths to science and reference images: + reference_image = os.path.join(datafile['archive_path'], datafile['template']['path']) + science_image = os.path.join(datafile['archive_path'], datafile['path']) + + # If the target was not provided in the function call, + # use the API to get the target information: + if target is None: + catalog = api.get_catalog(datafile['targetid'], output='table') + target = catalog['target'][0] + + # Find the path to where the ImageMatch program is installed. + # This is to avoid problems with it not being on the users PATH + # and if the user is using some other version of the python executable. + # TODO: There must be a better way of doing this! + # imgmatch = os.path.join(get_setuptools_script_dir(), 'ImageMatch') + if os.name == "nt": + out = subprocess.check_output(["where", "ImageMatch"], universal_newlines=True) + imgmatch = out.strip() + else: + out = subprocess.check_output(["whereis", "ImageMatch"], universal_newlines=True) + out = re.match('ImageMatch: (.+)', out.strip()) + imgmatch = out.group(1) + + if not os.path.isfile(imgmatch): + raise FileNotFoundError("ImageMatch not found") + + # Find the ImageMatch config file to use based on the site of the observations: + __dir__ = os.path.dirname(os.path.abspath(__file__)) + if datafile['site'] in (1, 3, 4, 6): + config_file = os.path.join(__dir__, 'imagematch', 'imagematch_lcogt.cfg') + elif datafile['site'] == 2: + config_file = os.path.join(__dir__, 'imagematch', 'imagematch_hawki.cfg') + elif datafile['site'] == 5: + config_file = os.path.join(__dir__, 'imagematch', 'imagematch_alfosc.cfg') + else: + config_file = os.path.join(__dir__, 'imagematch', 'imagematch_default.cfg') + if not os.path.isfile(config_file): + raise FileNotFoundError(config_file) + + if pixel_scale is None: + if datafile['site'] in (1, 3, 4, 6): + # LCOGT provides the pixel scale directly in the header + pixel_scale = 'PIXSCALE' + else: + image = load_image(science_image) + pixel_area = proj_plane_pixel_area(image.wcs) + pixel_scale = np.sqrt(pixel_area) * 3600 # arcsec/pixel + logger.info("Calculated science image pixel scale: %f", pixel_scale) + + if datafile['template']['site'] in (1, 3, 4, 6): + # LCOGT provides the pixel scale directly in the header + mscale = 'PIXSCALE' + else: + template = load_image(reference_image) + template_pixel_area = proj_plane_pixel_area(template.wcs.celestial) + mscale = np.sqrt(template_pixel_area) * 3600 # arcsec/pixel + logger.info("Calculated template pixel scale: %f", mscale) + + # Scale kernel radius with FWHM: + if fwhm is None: + kernel_radius = 9 + else: + kernel_radius = max(9, int(np.ceil(1.5 * fwhm))) + if kernel_radius % 2 == 0: + kernel_radius += 1 + + # We will work in a temporary directory, since ImageMatch produces + # a lot of extra output files that we don't want to have lying around + # after it completes + with tempfile.TemporaryDirectory() as tmpdir: + + # Copy the science and reference image to the temp dir: + shutil.copy(reference_image, tmpdir) + shutil.copy(science_image, tmpdir) + + # Construct the command to run ImageMatch: + for match_threshold in (3.0, 5.0, 7.0, 10.0): + cmd = '"{python:s}" "{imgmatch:s}" -cfg "{config_file:s}" -snx {target_ra:.10f}d -sny {target_dec:.10f}d -p {kernel_radius:d} -o {order:d} -s {match:f} -scale {pixel_scale:} -mscale {mscale:} -m "{reference_image:s}" "{science_image:s}"'.format( + python=sys.executable, imgmatch=imgmatch, config_file=config_file, + reference_image=os.path.basename(reference_image), science_image=os.path.basename(science_image), + target_ra=target['ra'], target_dec=target['decl'], match=match_threshold, kernel_radius=kernel_radius, + pixel_scale=pixel_scale, mscale=mscale, order=1) + logger.info("Executing command: %s", cmd) + + # Run the command in a subprocess: + cmd = shlex.split(cmd) + proc = subprocess.Popen(cmd, cwd=tmpdir, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=True) + stdout_data, stderr_data = proc.communicate() + returncode = proc.returncode + proc.kill() # Cleanup - Is this really needed? + + # Check the outputs from the subprocess: + logger.info("Return code: %d", returncode) + logger.info("STDOUT:\n%s", stdout_data.strip()) + if stderr_data.strip() != '': + logger.error("STDERR:\n%s", stderr_data.strip()) + if returncode < 0: + raise Exception("ImageMatch failed. Processed killed by OS with returncode %d." % returncode) + elif 'Failed object match... giving up.' in stdout_data: + # raise Exception("ImageMatch giving up matching objects") + continue + elif returncode > 0: + raise Exception("ImageMatch failed.") + + # Load the resulting difference image into memory: + diffimg_name = re.sub(r'\.fits(\.gz|\.bz2)?$', r'diff.fits\1', os.path.basename(science_image)) + diffimg_path = os.path.join(tmpdir, diffimg_name) + if not os.path.isfile(diffimg_path): + raise FileNotFoundError(diffimg_path) + + break + + else: + raise Exception("ImageMatch could not create difference image.") + + with fits.open(diffimg_path, mode='readonly') as hdu: + diffimg = np.asarray(hdu[0].data) + + return diffimg diff --git a/flows/tns.py b/flows/tns.py index df3f190..9b4811e 100644 --- a/flows/tns.py +++ b/flows/tns.py @@ -1,287 +1,236 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ TNS API FUNCTIONS -Pre-provided helper functions for the TNS API +Pre-provided helper functions for the TNS API, type annotations added Obtained from https://wis-tns.weizmann.ac.il/content/tns-getting-started - -.. codeauthor:: Emir Karamehmetoglu -.. codeauthor:: Rasmus Handberg """ - +from __future__ import annotations import logging from astropy.table import Table import astropy.units as u +from astropy.coordinates import SkyCoord import requests import json import datetime -from .config import load_config +from tendrils.utils import load_config +from typing import Optional, Union url_tns_api = 'https://www.wis-tns.org/api/get' url_tns_search = 'https://www.wis-tns.org/search' +DateType = Union[datetime.datetime, str] + -#-------------------------------------------------------------------------------------------------- class TNSConfigError(RuntimeError): - pass - -#-------------------------------------------------------------------------------------------------- -def _load_tns_config(): - - logger = logging.getLogger(__name__) - - config = load_config() - api_key = config.get('TNS', 'api_key', fallback=None) - if api_key is None: - raise TNSConfigError("No TNS API-KEY has been defined in config") - - tns_bot_id = config.getint('TNS', 'bot_id', fallback=93222) - tns_bot_name = config.get('TNS', 'bot_name', fallback='AUFLOWS_BOT') - tns_user_id = config.getint('TNS', 'user_id', fallback=None) - tns_user_name = config.get('TNS', 'user_name', fallback=None) - - if tns_user_id and tns_user_name: - logger.debug('Using TNS credentials: user=%s', tns_user_name) - user_agent = 'tns_marker{"tns_id":' + str(tns_user_id) + ',"type":"user","name":"' + tns_user_name + '"}' - elif tns_bot_id and tns_bot_name: - logger.debug('Using TNS credentials: bot=%s', tns_bot_name) - user_agent = 'tns_marker{"tns_id":' + str(tns_bot_id) + ',"type":"bot","name":"' + tns_bot_name + '"}' - else: - raise TNSConfigError("No TNS bot_id or bot_name has been defined in config") - - return { - 'api-key': api_key, - 'user-agent': user_agent - } - -#-------------------------------------------------------------------------------------------------- -def tns_search(coord=None, radius=3*u.arcsec, objname=None, internal_name=None): - """ - Cone-search TNS for object near coordinate. - - Parameters: - coord (:class:`astropy.coordinates.SkyCoord`): Central coordinate to search around. - radius (Angle, optional): Radius to search around ``coord``. - objname (str, optional): Search on object name. - internal_name (str, optional): Search on internal name. - - Returns: - dict: Dictionary with TSN response. - """ - - # API key for Bot - tnsconf = _load_tns_config() - - # change json_list to json format - json_file = { - 'radius': radius.to('arcsec').value, - 'units': 'arcsec', - 'objname': objname, - 'internal_name': internal_name - } - if coord: - json_file['ra'] = coord.icrs.ra.deg - json_file['dec'] = coord.icrs.dec.deg - - # construct the list of (key,value) pairs - headers = {'user-agent': tnsconf['user-agent']} - search_data = [ - ('api_key', (None, tnsconf['api-key'])), - ('data', (None, json.dumps(json_file))) - ] - - # search obj using request module - res = requests.post(url_tns_api + '/search', files=search_data, headers=headers) - res.raise_for_status() - parsed = res.json() - data = parsed['data'] - - if 'reply' in data: - return data['reply'] - return None - -#-------------------------------------------------------------------------------------------------- -def tns_get_obj(name): - """ - Search TNS for object by name. - - Parameters: - name (str): Object name to search for. - - Returns: - dict: Dictionary with TSN response. - """ - - # API key for Bot - tnsconf = _load_tns_config() - - # construct the list of (key,value) pairs - headers = {'user-agent': tnsconf['user-agent']} - params = {'objname': name, 'photometry': '0', 'spectra': '0'} - get_data = [ - ('api_key', (None, tnsconf['api-key'])), - ('data', (None, json.dumps(params))) - ] - - # get obj using request module - res = requests.post(url_tns_api + '/object', files=get_data, headers=headers) - res.raise_for_status() - parsed = res.json() - data = parsed['data'] - - if 'reply' in data: - reply = data['reply'] - if not reply: - return None - if 'objname' not in reply: # Bit of a cheat, but it is simple and works - return None - - reply['internal_names'] = [name.strip() for name in reply['internal_names'].split(',') if name.strip()] - return reply - return None - -#-------------------------------------------------------------------------------------------------- -def tns_getnames(months=None, date_begin=None, date_end=None, zmin=None, zmax=None, objtype=[3, 104]): - """ - Get SN names from TNS. - - Parameters: - months (int, optional): Only return objects reported within the last X months. - date_begin (date, optional): Discovery date begin. - date_end (date, optional): Discovery date end. - zmin (float, optional): Minimum redshift. - zmax (float, optional): Maximum redshift. - objtype (list, optional): Constraint object type. - Default is to query for - - 3: SN Ia - - 104: SN Ia-91T-like - - Returns: - list: List of names fulfilling search criteria. - """ - - logger = logging.getLogger(__name__) - - # Change formats of input to be ready for query: - if isinstance(date_begin, datetime.datetime): - date_begin = date_begin.date() - elif isinstance(date_begin, str): - date_begin = datetime.datetime.strptime(date_begin, '%Y-%m-%d').date() - - if isinstance(date_end, datetime.datetime): - date_end = date_end.date() - elif isinstance(date_end, str): - date_end = datetime.datetime.strptime(date_end, '%Y-%m-%d').date() - - if isinstance(objtype, (list, tuple)): - objtype = ','.join([str(o) for o in objtype]) - - # Do some sanity checks: - if date_end < date_begin: - raise ValueError("Dates are in the wrong order.") - - date_now = datetime.datetime.now(datetime.timezone.utc).date() - if months is not None and date_end is not None and date_end < date_now - datetime.timedelta(days=months*30): - logger.warning('Months limit restricts days_begin, consider increasing limit_months.') - - # API key for Bot - tnsconf = _load_tns_config() - - # Parameters for query: - params = { - 'discovered_period_value': months, # Reported Within The Last - 'discovered_period_units': 'months', - 'unclassified_at': 0, # Limit to unclasssified ATs - 'classified_sne': 1, # Limit to classified SNe - 'include_frb': 0, # Include FRBs - #'name': , - 'name_like': 0, - 'isTNS_AT': 'all', - 'public': 'all', - #'ra': - #'decl': - #'radius': - #'coords_unit': 'arcsec', - 'reporting_groupid[]': 'null', - 'groupid[]': 'null', - 'classifier_groupid[]': 'null', - 'objtype[]': objtype, - 'at_type[]': 'null', - 'date_start[date]': date_begin.isoformat(), - 'date_end[date]': date_end.isoformat(), - #'discovery_mag_min': - #'discovery_mag_max': - #'internal_name': - #'discoverer': - #'classifier': - #'spectra_count': - 'redshift_min': zmin, - 'redshift_max': zmax, - #'hostname': - #'ext_catid': - #'ra_range_min': - #'ra_range_max': - #'decl_range_min': - #'decl_range_max': - 'discovery_instrument[]': 'null', - 'classification_instrument[]': 'null', - 'associated_groups[]': 'null', - #'at_rep_remarks': - #'class_rep_remarks': - #'frb_repeat': 'all' - #'frb_repeater_of_objid': - 'frb_measured_redshift': 0, - #'frb_dm_range_min': - #'frb_dm_range_max': - #'frb_rm_range_min': - #'frb_rm_range_max': - #'frb_snr_range_min': - #'frb_snr_range_max': - #'frb_flux_range_min': - #'frb_flux_range_max': - 'num_page': 500, - 'display[redshift]': 0, - 'display[hostname]': 0, - 'display[host_redshift]': 0, - 'display[source_group_name]': 0, - 'display[classifying_source_group_name]': 0, - 'display[discovering_instrument_name]': 0, - 'display[classifing_instrument_name]': 0, - 'display[programs_name]': 0, - 'display[internal_name]': 0, - 'display[isTNS_AT]': 0, - 'display[public]': 0, - 'display[end_pop_period]': 0, - 'display[spectra_count]': 0, - 'display[discoverymag]': 0, - 'display[discmagfilter]': 0, - 'display[discoverydate]': 0, - 'display[discoverer]': 0, - 'display[remarks]': 0, - 'display[sources]': 0, - 'display[bibcode]': 0, - 'display[ext_catalogs]': 0, - 'format': 'csv' - } - - # Query TNS for names: - headers = {'user-agent': tnsconf['user-agent']} - con = requests.get(url_tns_search, params=params, headers=headers) - con.raise_for_status() - - # Parse the CSV table: - # Ensure that there is a newline in table string. - # AstroPy uses this to distinguish file-paths from pure-string inputs: - text = str(con.text) + "\n" - tab = Table.read(text, - format='ascii.csv', - guess=False, - delimiter=',', - quotechar='"', - header_start=0, - data_start=1) - - # Pull out the names only if they begin with "SN": - names_list = [name.replace(' ', '') for name in tab['Name'] if name.startswith('SN')] - names_list = sorted(names_list) - - return names_list + pass + + +def _load_tns_config() -> dict[str, str]: + logger = logging.getLogger(__name__) + + config = load_config() + api_key = config.get('TNS', 'api_key', fallback=None) + if api_key is None: + raise TNSConfigError("No TNS API-KEY has been defined in config") + + tns_bot_id = config.getint('TNS', 'bot_id', fallback=93222) + tns_bot_name = config.get('TNS', 'bot_name', fallback='AUFLOWS_BOT') + tns_user_id = config.getint('TNS', 'user_id', fallback=None) + tns_user_name = config.get('TNS', 'user_name', fallback=None) + + if tns_user_id and tns_user_name: + logger.debug('Using TNS credentials: user=%s', tns_user_name) + user_agent = 'tns_marker{"tns_id":' + str(tns_user_id) + ',"type":"user","name":"' + tns_user_name + '"}' + elif tns_bot_id and tns_bot_name: + logger.debug('Using TNS credentials: bot=%s', tns_bot_name) + user_agent = 'tns_marker{"tns_id":' + str(tns_bot_id) + ',"type":"bot","name":"' + tns_bot_name + '"}' + else: + raise TNSConfigError("No TNS bot_id or bot_name has been defined in config") + + return {'api-key': api_key, 'user-agent': user_agent} + + +def tns_search(coord: Optional[SkyCoord] = None, radius: u.Quantity = 3 * u.arcsec, objname: Optional[str] = None, + internal_name: Optional[str] = None) -> Optional[dict]: + """ + Cone-search TNS for object near coordinate. + + Parameters: + coord (:class:`astropy.coordinates.SkyCoord`): Central coordinate to search around. + radius (Angle, optional): Radius to search around ``coord``. + objname (str, optional): Search on object name. + internal_name (str, optional): Search on internal name. + + Returns: + dict: Dictionary with TSN response. + """ + + # API key for Bot + tnsconf = _load_tns_config() + + # change json_list to json format + json_file = {'radius': radius.to('arcsec').value, 'units': 'arcsec', 'objname': objname, + 'internal_name': internal_name} + if coord: + json_file['ra'] = coord.icrs.ra.deg + json_file['dec'] = coord.icrs.dec.deg + + # construct the list of (key,value) pairs + headers = {'user-agent': tnsconf['user-agent']} + search_data = [('api_key', (None, tnsconf['api-key'])), ('data', (None, json.dumps(json_file)))] + + # search obj using request module + res = requests.post(url_tns_api + '/search', files=search_data, headers=headers) + res.raise_for_status() + parsed = res.json() + data = parsed['data'] + + if 'reply' in data: + return data['reply'] + return None + + +def tns_get_obj(name:str) -> Optional[dict]: + """ + Search TNS for object by name. + + Parameters: + name (str): Object name to search for. + + Returns: + dict: Dictionary with TSN response. + """ + + # API key for Bot + tnsconf = _load_tns_config() + + # construct the list of (key,value) pairs + headers = {'user-agent': tnsconf['user-agent']} + params = {'objname': name, 'photometry': '0', 'spectra': '0'} + get_data = [('api_key', (None, tnsconf['api-key'])), ('data', (None, json.dumps(params)))] + + # get obj using request module + res = requests.post(url_tns_api + '/object', files=get_data, headers=headers) + res.raise_for_status() + parsed = res.json() + data = parsed['data'] + + if 'reply' in data: + reply = data['reply'] + if not reply: + return None + if 'objname' not in reply: # Bit of a cheat, but it is simple and works + return None + + reply['internal_names'] = [name.strip() for name in reply['internal_names'].split(',') if name.strip()] + return reply + return None + + +def tns_getnames(months: Optional[int] = None, date_begin: Optional[DateType] = None, + date_end: Optional[DateType] = None, zmin: Optional[float] = None, + zmax: Optional[float] = None, objtype: tuple[int] = (3, 104)) -> list[str]: + """ + Get SN names from TNS. + + Parameters: + months (int, optional): Only return objects reported within the last X months. + date_begin (date, optional): Discovery date begin. + date_end (date, optional): Discovery date end. + zmin (float, optional): Minimum redshift. + zmax (float, optional): Maximum redshift. + objtype (list, optional): Constraint object type. + Default is to query for + - 3: SN Ia + - 104: SN Ia-91T-like + + Returns: + list: List of names fulfilling search criteria. + """ + + logger = logging.getLogger(__name__) + + # Change formats of input to be ready for query: + if isinstance(date_begin, datetime.datetime): + date_begin = date_begin.date() + elif isinstance(date_begin, str): + date_begin = datetime.datetime.strptime(date_begin, '%Y-%m-%d').date() + + if isinstance(date_end, datetime.datetime): + date_end = date_end.date() + elif isinstance(date_end, str): + date_end = datetime.datetime.strptime(date_end, '%Y-%m-%d').date() + + if isinstance(objtype, (list, tuple)): + objtype = ','.join([str(o) for o in objtype]) + + # Do some sanity checks: + if date_end < date_begin: + raise ValueError("Dates are in the wrong order.") + + date_now = datetime.datetime.now(datetime.timezone.utc).date() + if months is not None and date_end is not None and date_end < date_now - datetime.timedelta(days=months * 30): + logger.warning('Months limit restricts days_begin, consider increasing limit_months.') + + # API key for Bot + tnsconf = _load_tns_config() + + # Parameters for query: + params = {'discovered_period_value': months, # Reported Within The Last + 'discovered_period_units': 'months', 'unclassified_at': 0, # Limit to unclasssified ATs + 'classified_sne': 1, # Limit to classified SNe + 'include_frb': 0, # Include FRBs + # 'name': , + 'name_like': 0, 'isTNS_AT': 'all', 'public': 'all', # 'ra': + # 'decl': + # 'radius': + # 'coords_unit': 'arcsec', + 'reporting_groupid[]': 'null', 'groupid[]': 'null', 'classifier_groupid[]': 'null', 'objtype[]': objtype, + 'at_type[]': 'null', 'date_start[date]': date_begin.isoformat(), 'date_end[date]': date_end.isoformat(), + # 'discovery_mag_min': + # 'discovery_mag_max': + # 'internal_name': + # 'discoverer': + # 'classifier': + # 'spectra_count': + 'redshift_min': zmin, 'redshift_max': zmax, # 'hostname': + # 'ext_catid': + # 'ra_range_min': + # 'ra_range_max': + # 'decl_range_min': + # 'decl_range_max': + 'discovery_instrument[]': 'null', 'classification_instrument[]': 'null', 'associated_groups[]': 'null', + # 'at_rep_remarks': + # 'class_rep_remarks': + # 'frb_repeat': 'all' + # 'frb_repeater_of_objid': + 'frb_measured_redshift': 0, # 'frb_dm_range_min': + # 'frb_dm_range_max': + # 'frb_rm_range_min': + # 'frb_rm_range_max': + # 'frb_snr_range_min': + # 'frb_snr_range_max': + # 'frb_flux_range_min': + # 'frb_flux_range_max': + 'num_page': 500, 'display[redshift]': 0, 'display[hostname]': 0, 'display[host_redshift]': 0, + 'display[source_group_name]': 0, 'display[classifying_source_group_name]': 0, + 'display[discovering_instrument_name]': 0, 'display[classifing_instrument_name]': 0, + 'display[programs_name]': 0, 'display[internal_name]': 0, 'display[isTNS_AT]': 0, 'display[public]': 0, + 'display[end_pop_period]': 0, 'display[spectra_count]': 0, 'display[discoverymag]': 0, + 'display[discmagfilter]': 0, 'display[discoverydate]': 0, 'display[discoverer]': 0, 'display[remarks]': 0, + 'display[sources]': 0, 'display[bibcode]': 0, 'display[ext_catalogs]': 0, 'format': 'csv'} + + # Query TNS for names: + headers = {'user-agent': tnsconf['user-agent']} + con = requests.get(url_tns_search, params=params, headers=headers) + con.raise_for_status() + + # Parse the CSV table: + # Ensure that there is a newline in table string. + # AstroPy uses this to distinguish file-paths from pure-string inputs: + text = str(con.text) + "\n" + tab = Table.read(text, format='ascii.csv', guess=False, delimiter=',', quotechar='"', header_start=0, data_start=1) + + # Pull out the names only if they begin with "SN": + names_list = [name.replace(' ', '') for name in tab['Name'] if name.startswith('SN')] + names_list = sorted(names_list) + + return names_list diff --git a/flows/utilities.py b/flows/utilities.py index e65054f..9e956ba 100644 --- a/flows/utilities.py +++ b/flows/utilities.py @@ -8,19 +8,20 @@ import hashlib -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def get_filehash(fname): - """Calculate SHA1-hash of file.""" - buf = 65536 - s = hashlib.sha1() - with open(fname, 'rb') as fid: - while True: - data = fid.read(buf) - if not data: - break - s.update(data) + """Calculate SHA1-hash of file.""" + buf = 65536 + s = hashlib.sha1() + with open(fname, 'rb') as fid: + while True: + data = fid.read(buf) + if not data: + break + s.update(data) - sha1sum = s.hexdigest().lower() - if len(sha1sum) != 40: - raise Exception("Invalid file hash") - return sha1sum + sha1sum = s.hexdigest().lower() + if len(sha1sum) != 40: + raise Exception("Invalid file hash") + return sha1sum diff --git a/flows/version.py b/flows/version.py index f746921..bd28768 100644 --- a/flows/version.py +++ b/flows/version.py @@ -30,128 +30,130 @@ # Find the "git" command to run depending on the OS: GIT_COMMAND = "git" if name == "nt": - def find_git_on_windows(): - """find the path to the git executable on windows""" - # first see if git is in the path - try: - check_output(["where", "/Q", "git"]) - # if this command succeeded, git is in the path - return "git" - # catch the exception thrown if git was not found - except CalledProcessError: - pass - # There are several locations git.exe may be hiding - possible_locations = [] - # look in program files for msysgit - if "PROGRAMFILES(X86)" in environ: - possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES(X86)"]) - if "PROGRAMFILES" in environ: - possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES"]) - # look for the github version of git - if "LOCALAPPDATA" in environ: - github_dir = "%s/GitHub" % environ["LOCALAPPDATA"] - if path.isdir(github_dir): - for subdir in listdir(github_dir): - if not subdir.startswith("PortableGit"): - continue - possible_locations.append("%s/%s/bin/git.exe" % (github_dir, subdir)) - for possible_location in possible_locations: - if path.isfile(possible_location): - return possible_location - # git was not found - return "git" - - GIT_COMMAND = find_git_on_windows() + def find_git_on_windows(): + """find the path to the git executable on windows""" + # first see if git is in the path + try: + check_output(["where", "/Q", "git"]) + # if this command succeeded, git is in the path + return "git" + # catch the exception thrown if git was not found + except CalledProcessError: + pass + # There are several locations git.exe may be hiding + possible_locations = [] + # look in program files for msysgit + if "PROGRAMFILES(X86)" in environ: + possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES(X86)"]) + if "PROGRAMFILES" in environ: + possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES"]) + # look for the github version of git + if "LOCALAPPDATA" in environ: + github_dir = "%s/GitHub" % environ["LOCALAPPDATA"] + if path.isdir(github_dir): + for subdir in listdir(github_dir): + if not subdir.startswith("PortableGit"): + continue + possible_locations.append("%s/%s/bin/git.exe" % (github_dir, subdir)) + for possible_location in possible_locations: + if path.isfile(possible_location): + return possible_location + # git was not found + return "git" + + + GIT_COMMAND = find_git_on_windows() def call_git_describe(abbrev=7): - """return the string output of git desribe""" - try: - with open(devnull, "w") as fnull: - arguments = [GIT_COMMAND, "describe", "--tags", "--abbrev=%d" % abbrev] - return check_output(arguments, cwd=CURRENT_DIRECTORY, - stderr=fnull).decode("ascii").strip() - except (OSError, CalledProcessError): - return None + """return the string output of git desribe""" + try: + with open(devnull, "w") as fnull: + arguments = [GIT_COMMAND, "describe", "--tags", "--abbrev=%d" % abbrev] + return check_output(arguments, cwd=CURRENT_DIRECTORY, stderr=fnull).decode("ascii").strip() + except (OSError, CalledProcessError): + return None + def call_git_getbranch(): - try: - with open(devnull, "w") as fnull: - arguments = [GIT_COMMAND, "symbolic-ref", "--short", "HEAD"] - return check_output(arguments, cwd=CURRENT_DIRECTORY, - stderr=fnull).decode("ascii").strip() - except (OSError, CalledProcessError): - return None + try: + with open(devnull, "w") as fnull: + arguments = [GIT_COMMAND, "symbolic-ref", "--short", "HEAD"] + return check_output(arguments, cwd=CURRENT_DIRECTORY, stderr=fnull).decode("ascii").strip() + except (OSError, CalledProcessError): + return None + def format_git_describe(git_str, pep440=False): - """format the result of calling 'git describe' as a python version""" - if git_str is None: - return None - if "-" not in git_str: # currently at a tag - return git_str - else: - # formatted as version-N-githash - # want to convert to version.postN-githash - git_str = git_str.replace("-", ".post", 1) - if pep440: # does not allow git hash afterwards - return git_str.split("-")[0] - else: - return git_str.replace("-g", "+git") + """format the result of calling 'git describe' as a python version""" + if git_str is None: + return None + if "-" not in git_str: # currently at a tag + return git_str + else: + # formatted as version-N-githash + # want to convert to version.postN-githash + git_str = git_str.replace("-", ".post", 1) + if pep440: # does not allow git hash afterwards + return git_str.split("-")[0] + else: + return git_str.replace("-g", "+git") + def read_release_version(): - """Read version information from VERSION file""" - try: - with open(VERSION_FILE, "r") as infile: - version = str(infile.read().strip()) - if len(version) == 0: - version = None - return version - except IOError: - return None + """Read version information from VERSION file""" + try: + with open(VERSION_FILE, "r") as infile: + version = str(infile.read().strip()) + if len(version) == 0: + version = None + return version + except IOError: + return None def update_release_version(): - """Update VERSION file""" - version = get_version(pep440=True) - with open(VERSION_FILE, "w") as outfile: - outfile.write(version) + """Update VERSION file""" + version = get_version(pep440=True) + with open(VERSION_FILE, "w") as outfile: + outfile.write(version) def get_version(pep440=False, include_branch=True): - """ - Tracks the version number. + """ + Tracks the version number. - The file VERSION holds the version information. If this is not a git - repository, then it is reasonable to assume that the version is not - being incremented and the version returned will be the release version as - read from the file. + The file VERSION holds the version information. If this is not a git + repository, then it is reasonable to assume that the version is not + being incremented and the version returned will be the release version as + read from the file. - However, if the script is located within an active git repository, - git-describe is used to get the version information. + However, if the script is located within an active git repository, + git-describe is used to get the version information. - The file VERSION will need to be changed by manually. This should be done - before running git tag (set to the same as the version in the tag). + The file VERSION will need to be changed by manually. This should be done + before running git tag (set to the same as the version in the tag). - Parameters: - pep440 (bool): When True, this function returns a version string suitable for - a release as defined by PEP 440. When False, the githash (if - available) will be appended to the version string. + Parameters: + pep440 (bool): When True, this function returns a version string suitable for + a release as defined by PEP 440. When False, the githash (if + available) will be appended to the version string. - Returns: - string: Version sting. - """ + Returns: + string: Version sting. + """ - git_version = format_git_describe(call_git_describe(), pep440=pep440) - if git_version is None: # not a git repository - return read_release_version() + git_version = format_git_describe(call_git_describe(), pep440=pep440) + if git_version is None: # not a git repository + return read_release_version() - if include_branch: - git_branch = call_git_getbranch() - if git_branch is not None: - git_version = git_branch + '-' + git_version + if include_branch: + git_branch = call_git_getbranch() + if git_branch is not None: + git_version = git_branch + '-' + git_version - return git_version + return git_version if __name__ == "__main__": - print(get_version()) + print(get_version()) diff --git a/flows/visibility.py b/flows/visibility.py index 667a219..2736b6b 100644 --- a/flows/visibility.py +++ b/flows/visibility.py @@ -16,118 +16,121 @@ from datetime import datetime from astropy.coordinates import SkyCoord, AltAz, get_sun, get_moon from astropy.visualization import quantity_support -from . import api +from tendrils import api -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def visibility(target, siteid=None, date=None, output=None, overwrite=True): - """ - Create visibility plot. - - Parameters: - target (str or int): - siteid (int): Identifier of site. - date (datetime or str, optional): Date for which to create visibility plot. - Default it to use the current date. - output (str, optional): Path to file or directory where to place the plot. - If not given, the plot will be created in memory, and can be shown on screen. - overwrite (bool, optional): Should existing file specified in ``output`` be overwritten? - Default is to overwrite an existing file. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - if date is None: - date = datetime.utcnow() - elif isinstance(date, str): - date = datetime.strptime(date, '%Y-%m-%d') - - tgt = api.get_target(target) - - # Coordinates of object: - obj = SkyCoord(ra=tgt['ra'], dec=tgt['decl'], unit='deg', frame='icrs') - - if siteid is None: - sites = api.get_all_sites() - else: - sites = [api.get_site(siteid)] - - plotpaths = [] - for site in sites: - # If we are saving plot to file, determine the path to save to - # and check that it doesn't already exist: - if output: - if os.path.isdir(output): - plotpath = os.path.join(output, "visibility_%s_%s_site%02d.png" % ( - tgt['target_name'], - date.strftime('%Y%m%d'), - site['siteid'])) - else: - plotpath = output - logger.debug("Will save visibility plot to '%s'", plotpath) - - # If we are not overwriting and - if not overwrite and os.path.exists(plotpath): - logger.info("File already exists: %s", plotpath) - continue - - # Observatory: - observatory = site['EarthLocation'] - utcoffset = (site['longitude']*u.deg/(360*u.deg)) * 24*u.hour - - # Create timestamps to calculate for: - midnight = Time(date.strftime('%Y-%m-%d') + ' 00:00:00', scale='utc') - utcoffset - delta_midnight = np.linspace(-12, 12, 1000)*u.hour - times = midnight + delta_midnight - - # AltAz frame: - AltAzFrame = AltAz(obstime=times, location=observatory) - - # Object: - altaz_obj = obj.transform_to(AltAzFrame) - - # The Sun and Moon: - altaz_sun = get_sun(times).transform_to(AltAzFrame) - altaz_moon = get_moon(times).transform_to(AltAzFrame) - - sundown_astro = (altaz_sun.alt < -6*u.deg) - if np.any(sundown_astro): - min_time = np.min(times[sundown_astro]) - 2*u.hour - max_time = np.max(times[sundown_astro]) + 2*u.hour - else: - min_time = times[0] - max_time = times[-1] - - quantity_support() - fig, ax = plt.subplots(1, 1, figsize=(15,9), squeeze=True) - plt.grid(ls=':', lw=0.5) - ax.plot(times.datetime, altaz_sun.alt, color='y', label='Sun') - ax.plot(times.datetime, altaz_moon.alt, color=[0.75]*3, ls='--', label='Moon') - objsc = ax.scatter(times.datetime, altaz_obj.alt, c=altaz_obj.az, label=tgt['target_name'], lw=0, s=8, cmap='twilight') - ax.fill_between(times.datetime, 0*u.deg, 90*u.deg, altaz_sun.alt < -0*u.deg, color='0.5', zorder=0) # , label='Night' - ax.fill_between(times.datetime, 0*u.deg, 90*u.deg, altaz_sun.alt < -18*u.deg, color='k', zorder=0) # , label='Astronomical Night' - - plt.colorbar(objsc, ax=ax, pad=0.01).set_label('Azimuth [deg]') - ax.legend(loc='upper left') - ax.minorticks_on() - ax.set_xlim(min_time.datetime, max_time.datetime) - ax.set_ylim(0*u.deg, 90*u.deg) - ax.set_title("%s - %s - %s" % (str(tgt['target_name']), date.strftime('%Y-%m-%d'), site['sitename']), fontsize=14) - plt.xlabel('Time [UTC]', fontsize=14) - plt.ylabel('Altitude [deg]', fontsize=16) - fig.autofmt_xdate() - - formatter = DateFormatter('%d/%m %H:%M') - ax.xaxis.set_major_formatter(formatter) - - if output: - fig.savefig(plotpath, bbox_inches='tight', transparent=True) - plt.close(fig) - plotpaths.append(plotpath) - - if output: - return plotpaths - - plt.show() - return ax + """ + Create visibility plot. + + Parameters: + target (str or int): + siteid (int): Identifier of site. + date (datetime or str, optional): Date for which to create visibility plot. + Default it to use the current date. + output (str, optional): Path to file or directory where to place the plot. + If not given, the plot will be created in memory, and can be shown on screen. + overwrite (bool, optional): Should existing file specified in ``output`` be overwritten? + Default is to overwrite an existing file. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + + if date is None: + date = datetime.utcnow() + elif isinstance(date, str): + date = datetime.strptime(date, '%Y-%m-%d') + + tgt = api.get_target(target) + + # Coordinates of object: + obj = SkyCoord(ra=tgt['ra'], dec=tgt['decl'], unit='deg', frame='icrs') + + if siteid is None: + sites = api.get_all_sites() + else: + sites = [api.get_site(siteid)] + + plotpaths = [] + for site in sites: + # If we are saving plot to file, determine the path to save to + # and check that it doesn't already exist: + if output: + if os.path.isdir(output): + plotpath = os.path.join(output, "visibility_%s_%s_site%02d.png" % ( + tgt['target_name'], date.strftime('%Y%m%d'), site['siteid'])) + else: + plotpath = output + logger.debug("Will save visibility plot to '%s'", plotpath) + + # If we are not overwriting and + if not overwrite and os.path.exists(plotpath): + logger.info("File already exists: %s", plotpath) + continue + + # Observatory: + observatory = site['EarthLocation'] + utcoffset = (site['longitude'] * u.deg / (360 * u.deg)) * 24 * u.hour + + # Create timestamps to calculate for: + midnight = Time(date.strftime('%Y-%m-%d') + ' 00:00:00', scale='utc') - utcoffset + delta_midnight = np.linspace(-12, 12, 1000) * u.hour + times = midnight + delta_midnight + + # AltAz frame: + AltAzFrame = AltAz(obstime=times, location=observatory) + + # Object: + altaz_obj = obj.transform_to(AltAzFrame) + + # The Sun and Moon: + altaz_sun = get_sun(times).transform_to(AltAzFrame) + altaz_moon = get_moon(times).transform_to(AltAzFrame) + + sundown_astro = (altaz_sun.alt < -6 * u.deg) + if np.any(sundown_astro): + min_time = np.min(times[sundown_astro]) - 2 * u.hour + max_time = np.max(times[sundown_astro]) + 2 * u.hour + else: + min_time = times[0] + max_time = times[-1] + + quantity_support() + fig, ax = plt.subplots(1, 1, figsize=(15, 9), squeeze=True) + plt.grid(ls=':', lw=0.5) + ax.plot(times.datetime, altaz_sun.alt, color='y', label='Sun') + ax.plot(times.datetime, altaz_moon.alt, color=[0.75] * 3, ls='--', label='Moon') + objsc = ax.scatter(times.datetime, altaz_obj.alt, c=altaz_obj.az, label=tgt['target_name'], lw=0, s=8, + cmap='twilight') + ax.fill_between(times.datetime, 0 * u.deg, 90 * u.deg, altaz_sun.alt < -0 * u.deg, color='0.5', + zorder=0) # , label='Night' + ax.fill_between(times.datetime, 0 * u.deg, 90 * u.deg, altaz_sun.alt < -18 * u.deg, color='k', + zorder=0) # , label='Astronomical Night' + + plt.colorbar(objsc, ax=ax, pad=0.01).set_label('Azimuth [deg]') + ax.legend(loc='upper left') + ax.minorticks_on() + ax.set_xlim(min_time.datetime, max_time.datetime) + ax.set_ylim(0 * u.deg, 90 * u.deg) + ax.set_title("%s - %s - %s" % (str(tgt['target_name']), date.strftime('%Y-%m-%d'), site['sitename']), + fontsize=14) + plt.xlabel('Time [UTC]', fontsize=14) + plt.ylabel('Altitude [deg]', fontsize=16) + fig.autofmt_xdate() + + formatter = DateFormatter('%d/%m %H:%M') + ax.xaxis.set_major_formatter(formatter) + + if output: + fig.savefig(plotpath, bbox_inches='tight', transparent=True) + plt.close(fig) + plotpaths.append(plotpath) + + if output: + return plotpaths + + plt.show() + return ax diff --git a/flows/zeropoint.py b/flows/zeropoint.py index 39669da..486f7b2 100644 --- a/flows/zeropoint.py +++ b/flows/zeropoint.py @@ -15,54 +15,53 @@ from scipy.special import erfcinv -#Calculate sigma for sigma clipping using Chauvenet +# Calculate sigma for sigma clipping using Chauvenet def sigma_from_Chauvenet(Nsamples): - '''Calculate sigma according to the Cheuvenet criterion''' - return erfcinv(1./(2*Nsamples)) * (2.)**(1/2) + '''Calculate sigma according to the Cheuvenet criterion''' + return erfcinv(1. / (2 * Nsamples)) * (2.) ** (1 / 2) -def bootstrap_outlier(x,y,yerr, n=500, model='None',fitter='None', - outlier='None', outlier_kwargs={'sigma':3}, summary='median', error='bootstrap', - parnames=['intercept'], return_vals=True): - '''x = catalog mag, y = instrumental mag, yerr = instrumental error - summary = function for summary statistic, np.nanmedian by default. - model = Linear1D - fitter = LinearLSQFitter - outlier = 'sigma_clip' - outlier_kwargs, default sigma = 3 - return_vals = False will return dictionary - Performs bootstrap with replacement and returns model. - ''' - summary = np.nanmedian if summary == 'median' else summary - error = np.nanstd if error == 'bootstrap' else error +def bootstrap_outlier(x, y, yerr, n=500, model='None', fitter='None', outlier='None', outlier_kwargs={'sigma': 3}, + summary='median', error='bootstrap', parnames=['intercept'], return_vals=True): + '''x = catalog mag, y = instrumental mag, yerr = instrumental error + summary = function for summary statistic, np.nanmedian by default. + model = Linear1D + fitter = LinearLSQFitter + outlier = 'sigma_clip' + outlier_kwargs, default sigma = 3 + return_vals = False will return dictionary + Performs bootstrap with replacement and returns model. + ''' + summary = np.nanmedian if summary == 'median' else summary + error = np.nanstd if error == 'bootstrap' else error - #Create index for bootstrapping - ind = np.arange(len(x)) + # Create index for bootstrapping + ind = np.arange(len(x)) - #Bootstrap indexes with replacement using astropy - bootstraps = bootstrap(ind,bootnum=n) - bootstraps.sort() # sort increasing. - bootinds = bootstraps.astype(int) + # Bootstrap indexes with replacement using astropy + bootstraps = bootstrap(ind, bootnum=n) + bootstraps.sort() # sort increasing. + bootinds = bootstraps.astype(int) - #Prepare fitter - fitter_instance = fitting.FittingWithOutlierRemoval(fitter(),outlier, **outlier_kwargs) - #Fit each bootstrap with model and fitter using outlier rejection at each step. - #Then obtain summary statistic for each parameter in parnames - pars = {} - out = {} - for parname in parnames: - pars[parname] = np.ones(len(bootinds), dtype=np.float64) - for i,bs in enumerate(bootinds): - #w = np.ones(len(x[bs]), dtype=np.float64) if yerr=='None' else (1.0/yerr[bs])**2 - w = (1.0/yerr[bs])**2 - best_fit, sigma_clipped = fitter_instance(model, x[bs], y[bs], weights=w) - #obtain parameters of interest - for parname in parnames: - pars[parname][i] = best_fit.parameters[np.array(best_fit.param_names) == parname][0] - if return_vals: - return [summary(pars[par]) for par in pars] + # Prepare fitter + fitter_instance = fitting.FittingWithOutlierRemoval(fitter(), outlier, **outlier_kwargs) + # Fit each bootstrap with model and fitter using outlier rejection at each step. + # Then obtain summary statistic for each parameter in parnames + pars = {} + out = {} + for parname in parnames: + pars[parname] = np.ones(len(bootinds), dtype=np.float64) + for i, bs in enumerate(bootinds): + # w = np.ones(len(x[bs]), dtype=np.float64) if yerr=='None' else (1.0/yerr[bs])**2 + w = (1.0 / yerr[bs]) ** 2 + best_fit, sigma_clipped = fitter_instance(model, x[bs], y[bs], weights=w) + # obtain parameters of interest + for parname in parnames: + pars[parname][i] = best_fit.parameters[np.array(best_fit.param_names) == parname][0] + if return_vals: + return [summary(pars[par]) for par in pars] - for parname in parnames: - out[parname] = summary(pars[parname]) - out[parname+'_error'] = error(pars[parname]) - return out + for parname in parnames: + out[parname] = summary(pars[parname]) + out[parname + '_error'] = error(pars[parname]) + return out diff --git a/flows/ztf.py b/flows/ztf.py index dc74d8e..75ad66c 100644 --- a/flows/ztf.py +++ b/flows/ztf.py @@ -1,11 +1,6 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Query ZTF target information using ALeRCE API. https://alerceapi.readthedocs.io/ - -.. codeauthor:: Emir Karamehmetoglu -.. codeauthor:: Rasmus Handberg """ import numpy as np @@ -15,135 +10,129 @@ from astropy.time import Time import datetime import requests -from . import api - -#-------------------------------------------------------------------------------------------------- -def query_ztf_id(coo_centre, radius=3*u.arcsec, discovery_date=None): - """ - Query ALeRCE ZTF api to lookup ZTF identifier. - - In case multiple identifiers are found within the search cone, the one - closest to the centre is returned. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (Angle, optional): Search radius. Default 3 arcsec. - discovery_date (:class:`astropy.time.Time`, optional): Discovery date of target to - match against ZTF. The date is compared to the ZTF first timestamp and ZTF targets - are rejected if they are not within 15 days prior to the discovery date - and 90 days after. - - Returns: - str: ZTF identifier. - - .. codeauthor:: Emir Karamehmetoglu - .. codeauthor:: Rasmus Handberg - """ - - if isinstance(radius, (float, int)): - radius *= u.deg - - # Make json query for Alerce query API - query = { - 'ra': coo_centre.ra.deg, - 'dec': coo_centre.dec.deg, - 'radius': Angle(radius).arcsec, - 'page_size': 20, - 'count': True - } - - # Run http POST json query to alerce following their API - res = requests.get('https://api.alerce.online/ztf/v1/objects', params=query) - res.raise_for_status() - jsn = res.json() - - # If nothing was found, return None: - if jsn['total'] == 0: - return None - - # Start by removing anything marked as likely stellar-like source: - results = jsn['items'] - results = [itm for itm in results if not itm['stellar']] - if not results: - return None - - # Constrain on the discovery date if it is provided: - if discovery_date is not None: - # Extract the time of the first ZTF timestamp and compare it with - # the discovery time: - firstmjd = Time([itm['firstmjd'] for itm in results], format='mjd', scale='utc') - tdelta = firstmjd.utc.mjd - discovery_date.utc.mjd - - # Only keep results that are within the margins: - results = [itm for k, itm in enumerate(results) if -15 <= tdelta[k] <= 90] - if not results: - return None - - # Find target closest to the centre: - coords = SkyCoord( - ra=[itm['meanra'] for itm in results], - dec=[itm['meandec'] for itm in results], - unit='deg', - frame='icrs') - - indx = np.argmin(coords.separation(coo_centre)) - - return results[indx]['oid'] - -#-------------------------------------------------------------------------------------------------- +from tendrils import api + + +# -------------------------------------------------------------------------------------------------- +def query_ztf_id(coo_centre, radius=3 * u.arcsec, discovery_date=None): + """ + Query ALeRCE ZTF api to lookup ZTF identifier. + + In case multiple identifiers are found within the search cone, the one + closest to the centre is returned. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (Angle, optional): Search radius. Default 3 arcsec. + discovery_date (:class:`astropy.time.Time`, optional): Discovery date of target to + match against ZTF. The date is compared to the ZTF first timestamp and ZTF targets + are rejected if they are not within 15 days prior to the discovery date + and 90 days after. + + Returns: + str: ZTF identifier. + + .. codeauthor:: Emir Karamehmetoglu + .. codeauthor:: Rasmus Handberg + """ + + if isinstance(radius, (float, int)): + radius *= u.deg + + # Make json query for Alerce query API + query = {'ra': coo_centre.ra.deg, 'dec': coo_centre.dec.deg, 'radius': Angle(radius).arcsec, 'page_size': 20, + 'count': True} + + # Run http POST json query to alerce following their API + res = requests.get('https://api.alerce.online/ztf/v1/objects', params=query) + res.raise_for_status() + jsn = res.json() + + # If nothing was found, return None: + if jsn['total'] == 0: + return None + + # Start by removing anything marked as likely stellar-like source: + results = jsn['items'] + results = [itm for itm in results if not itm['stellar']] + if not results: + return None + + # Constrain on the discovery date if it is provided: + if discovery_date is not None: + # Extract the time of the first ZTF timestamp and compare it with + # the discovery time: + firstmjd = Time([itm['firstmjd'] for itm in results], format='mjd', scale='utc') + tdelta = firstmjd.utc.mjd - discovery_date.utc.mjd + + # Only keep results that are within the margins: + results = [itm for k, itm in enumerate(results) if -15 <= tdelta[k] <= 90] + if not results: + return None + + # Find target closest to the centre: + coords = SkyCoord(ra=[itm['meanra'] for itm in results], dec=[itm['meandec'] for itm in results], unit='deg', + frame='icrs') + + indx = np.argmin(coords.separation(coo_centre)) + + return results[indx]['oid'] + + +# -------------------------------------------------------------------------------------------------- def download_ztf_photometry(targetid): - """ - Download ZTF photometry from ALERCE API. - - Parameters: - targetid (int): Target identifier. - - Returns: - :class:`astropy.table.Table`: ZTF photometry table. - - .. codeauthor:: Emir Karamehmetoglu - .. codeauthor:: Rasmus Handberg - """ - - # Get target info from Flows API: - tgt = api.get_target(targetid) - oid = tgt['ztf_id'] - target_name = tgt['target_name'] - if oid is None: - return None - - # Query ALERCE for detections of object based on oid - res = requests.get(f'https://api.alerce.online/ztf/v1/objects/{oid:s}/detections') - res.raise_for_status() - jsn = res.json() - - # Create Astropy table, cut out the needed columns - # and rename columns to something better for what we are doing: - tab = Table(data=jsn) - tab = tab[['fid', 'mjd', 'magpsf', 'sigmapsf']] - tab.rename_column('fid', 'photfilter') - tab.rename_column('mjd', 'time') - tab.rename_column('magpsf', 'mag') - tab.rename_column('sigmapsf', 'mag_err') - - # Remove bad values of time and magnitude: - tab['time'] = np.asarray(tab['time'], dtype='float64') - tab['mag'] = np.asarray(tab['mag'], dtype='float64') - tab['mag_err'] = np.asarray(tab['mag_err'], dtype='float64') - indx = np.isfinite(tab['time']) & np.isfinite(tab['mag']) & np.isfinite(tab['mag_err']) - tab = tab[indx] - - # Replace photometric filter numbers with keywords used in Flows: - photfilter_dict = {1: 'gp', 2: 'rp', 3: 'ip'} - tab['photfilter'] = [photfilter_dict[fid] for fid in tab['photfilter']] - - # Sort the table on photfilter and time: - tab.sort(['photfilter', 'time']) - - # Add meta information to table header: - tab.meta['target_name'] = target_name - tab.meta['targetid'] = targetid - tab.meta['ztf_id'] = oid - tab.meta['last_updated'] = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() - - return tab + """ + Download ZTF photometry from ALERCE API. + + Parameters: + targetid (int): Target identifier. + + Returns: + :class:`astropy.table.Table`: ZTF photometry table. + + .. codeauthor:: Emir Karamehmetoglu + .. codeauthor:: Rasmus Handberg + """ + + # Get target info from Flows API: + tgt = api.get_target(targetid) + oid = tgt['ztf_id'] + target_name = tgt['target_name'] + if oid is None: + return None + + # Query ALERCE for detections of object based on oid + res = requests.get(f'https://api.alerce.online/ztf/v1/objects/{oid:s}/detections') + res.raise_for_status() + jsn = res.json() + + # Create Astropy table, cut out the needed columns + # and rename columns to something better for what we are doing: + tab = Table(data=jsn) + tab = tab[['fid', 'mjd', 'magpsf', 'sigmapsf']] + tab.rename_column('fid', 'photfilter') + tab.rename_column('mjd', 'time') + tab.rename_column('magpsf', 'mag') + tab.rename_column('sigmapsf', 'mag_err') + + # Remove bad values of time and magnitude: + tab['time'] = np.asarray(tab['time'], dtype='float64') + tab['mag'] = np.asarray(tab['mag'], dtype='float64') + tab['mag_err'] = np.asarray(tab['mag_err'], dtype='float64') + indx = np.isfinite(tab['time']) & np.isfinite(tab['mag']) & np.isfinite(tab['mag_err']) + tab = tab[indx] + + # Replace photometric filter numbers with keywords used in Flows: + photfilter_dict = {1: 'gp', 2: 'rp', 3: 'ip'} + tab['photfilter'] = [photfilter_dict[fid] for fid in tab['photfilter']] + + # Sort the table on photfilter and time: + tab.sort(['photfilter', 'time']) + + # Add meta information to table header: + tab.meta['target_name'] = target_name + tab.meta['targetid'] = targetid + tab.meta['ztf_id'] = oid + tab.meta['last_updated'] = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + + return tab diff --git a/notes/disk_covering_problem.py b/notes/disk_covering_problem.py index 02e30dd..6824d49 100644 --- a/notes/disk_covering_problem.py +++ b/notes/disk_covering_problem.py @@ -12,34 +12,31 @@ from matplotlib.patches import Circle import astropy.units as u -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - coo_centre = SkyCoord(ra=0, dec=0, unit='deg', frame='icrs') + coo_centre = SkyCoord(ra=0, dec=0, unit='deg', frame='icrs') - radius = 24.0/60.0 + radius = 24.0 / 60.0 - #aframe = SkyOffsetFrame(origin=coo_centre) - #c = coo_centre.transform_to(aframe) - #print(c) + # aframe = SkyOffsetFrame(origin=coo_centre) + # c = coo_centre.transform_to(aframe) + # print(c) - fig, ax = plt.subplots() - ax.plot(coo_centre.ra.deg, coo_centre.dec.deg, 'rx') + fig, ax = plt.subplots() + ax.plot(coo_centre.ra.deg, coo_centre.dec.deg, 'rx') - ax.add_artist(Circle([coo_centre.ra.deg, coo_centre.dec.deg], radius=radius, ec='r', fc=None, fill=False)) - ax.add_artist(Circle([coo_centre.ra.deg, coo_centre.dec.deg], radius=0.5*radius, ec='b', fc=None, fill=False)) + ax.add_artist(Circle([coo_centre.ra.deg, coo_centre.dec.deg], radius=radius, ec='r', fc=None, fill=False)) + ax.add_artist(Circle([coo_centre.ra.deg, coo_centre.dec.deg], radius=0.5 * radius, ec='b', fc=None, fill=False)) - for n in range(6): - new = SkyCoord( - ra=coo_centre.ra.deg + 0.8 * radius * np.cos(n*60*np.pi/180), - dec=coo_centre.dec.deg + 0.8 * radius * np.sin(n*60*np.pi/180), - unit='deg', frame='icrs') + for n in range(6): + new = SkyCoord(ra=coo_centre.ra.deg + 0.8 * radius * np.cos(n * 60 * np.pi / 180), + dec=coo_centre.dec.deg + 0.8 * radius * np.sin(n * 60 * np.pi / 180), unit='deg', frame='icrs') - ax.plot(new.ra.deg, new.dec.deg, 'bx') - ax.add_artist(Circle([new.ra.deg, new.dec.deg], radius=0.5*radius, ec='b', fc=None, fill=False)) + ax.plot(new.ra.deg, new.dec.deg, 'bx') + ax.add_artist(Circle([new.ra.deg, new.dec.deg], radius=0.5 * radius, ec='b', fc=None, fill=False)) - - plt.axis('equal') - #ax.set_xlim(coo_centre.ra.deg + radius * np.array([-2, 2])) - #ax.set_ylim(coo_centre.dec.deg +radius * np.array([-2, 2])) - plt.show() + plt.axis('equal') + # ax.set_xlim(coo_centre.ra.deg + radius * np.array([-2, 2])) + # ax.set_ylim(coo_centre.dec.deg +radius * np.array([-2, 2])) + plt.show() diff --git a/notes/fix_ztf_ids.py b/notes/fix_ztf_ids.py index e8df20c..d5268cb 100644 --- a/notes/fix_ztf_ids.py +++ b/notes/fix_ztf_ids.py @@ -5,29 +5,30 @@ import os.path from tqdm import tqdm from astropy.coordinates import SkyCoord + if os.path.abspath('..') not in sys.path: - sys.path.insert(0, os.path.abspath('..')) + sys.path.insert(0, os.path.abspath('..')) import flows -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - with flows.aadc_db.AADC_DB() as db: - for target in tqdm(flows.api.get_targets()): - if target['target_status'] == 'rejected': - continue + with flows.aadc_db.AADC_DB() as db: + for target in tqdm(flows.api.get_targets()): + if target['target_status'] == 'rejected': + continue - targetid = target['targetid'] - coord = SkyCoord(ra=target['ra'], dec=target['decl'], unit='deg', frame='icrs') - dd = target['discovery_date'] + targetid = target['targetid'] + coord = SkyCoord(ra=target['ra'], dec=target['decl'], unit='deg', frame='icrs') + dd = target['discovery_date'] - # Query for the ZTF id: - ztf_id = flows.ztf.query_ztf_id(coord, discovery_date=dd) + # Query for the ZTF id: + ztf_id = flows.ztf.query_ztf_id(coord, discovery_date=dd) - # If the ZTF id is not the same as we have currently, update it in the database: - if ztf_id != target['ztf_id']: - print(target) - print(ztf_id) - print("******* NEEDS UPDATE ******") + # If the ZTF id is not the same as we have currently, update it in the database: + if ztf_id != target['ztf_id']: + print(target) + print(ztf_id) + print("******* NEEDS UPDATE ******") - db.cursor.execute("UPDATE flows.targets SET ztf_id=%s WHERE targetid=%s;", (ztf_id, targetid)) - db.conn.commit() + db.cursor.execute("UPDATE flows.targets SET ztf_id=%s WHERE targetid=%s;", (ztf_id, targetid)) + db.conn.commit() diff --git a/notes/update_all_catalogs.py b/notes/update_all_catalogs.py index 7356441..3fcd5c0 100644 --- a/notes/update_all_catalogs.py +++ b/notes/update_all_catalogs.py @@ -6,47 +6,51 @@ import os.path import tqdm from astropy.coordinates import SkyCoord + if os.path.abspath('..') not in sys.path: - sys.path.insert(0, os.path.abspath('..')) + sys.path.insert(0, os.path.abspath('..')) import flows -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class TqdmLoggingHandler(logging.Handler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def emit(self, record): - try: - msg = self.format(record) - tqdm.tqdm.write(msg) - self.flush() - except (KeyboardInterrupt, SystemExit): # pragma: no cover - raise - except: # noqa: E722, pragma: no cover - self.handleError(record) - -#-------------------------------------------------------------------------------------------------- + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.tqdm.write(msg) + self.flush() + except (KeyboardInterrupt, SystemExit): # pragma: no cover + raise + except: # noqa: E722, pragma: no cover + self.handleError(record) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = TqdmLoggingHandler() - console.setFormatter(formatter) - logger = logging.getLogger('flows') - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging.INFO) - - # Do it by status, just to prioritize things a bit: - for tgtstatus in ('target', 'candidate', 'rejected'): - targetids = sorted([tgt['targetid'] for tgt in flows.api.get_targets() if tgt['target_status'] == tgtstatus])[::-1] - - for targetid in tqdm.tqdm(targetids, desc=tgtstatus): - donefile = f"catalog_updates/{targetid:05d}.done" - if not os.path.exists(donefile): - try: - flows.catalogs.download_catalog(targetid, update_existing=True) - except: - logger.exception("targetid=%d", targetid) - else: - open(donefile, 'w').close() + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = TqdmLoggingHandler() + console.setFormatter(formatter) + logger = logging.getLogger('flows') + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging.INFO) + + # Do it by status, just to prioritize things a bit: + for tgtstatus in ('target', 'candidate', 'rejected'): + targetids = sorted([tgt['targetid'] for tgt in flows.api.get_targets() if tgt['target_status'] == tgtstatus])[ + ::-1] + + for targetid in tqdm.tqdm(targetids, desc=tgtstatus): + donefile = f"catalog_updates/{targetid:05d}.done" + if not os.path.exists(donefile): + try: + flows.catalogs.download_catalog(targetid, update_existing=True) + except: + logger.exception("targetid=%d", targetid) + else: + open(donefile, 'w').close() diff --git a/requirements.txt b/requirements.txt index 136d3ac..6e5a42c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,10 +4,9 @@ flake8-tabs >= 2.3.2 flake8-builtins flake8-logging-format numpy >= 1.16 -scipy == 1.5.4 +scipy >= 1.5.4 astropy == 4.1 photutils == 1.1.0; python_version >= '3.7' -photutils == 1.0.2; python_version < '3.7' Bottleneck == 1.3.2 matplotlib == 3.3.1 mplcursors == 0.3 @@ -18,7 +17,7 @@ PyYAML psycopg2-binary jplephem vtk -scikit-image == 0.17.2 +scikit-image >= 0.17.2 tqdm pytz git+https://github.com/obscode/imagematch.git@photutils#egg=imagematch @@ -26,3 +25,4 @@ sep astroalign > 2.3 networkx astroquery >= 0.4.2 +tendrils >= 0.1.5 \ No newline at end of file diff --git a/run_catalogs.py b/run_catalogs.py index c73a2d4..5e283d1 100644 --- a/run_catalogs.py +++ b/run_catalogs.py @@ -1,47 +1,58 @@ -# -*- coding: utf-8 -*- """ - -.. codeauthor:: Rasmus Handberg +Runner to add target catalog, if catalog info is missing. Print output. """ - import argparse import logging -from flows import api, download_catalog +from tendrils import api +from flows import download_catalog + + +def parse(): + """ + # Parse command line arguments: + """ + parser = argparse.ArgumentParser(description='Run catalog.') + parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') + parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') + parser.add_argument('-t', '--target', type=str, help='Target to print catalog for.', nargs='?', default=None) + return parser.parse_args() + +def set_logging_level(args): + # Set logging level: + logging_level = logging.INFO + if args.quiet: + logging_level = logging.WARNING + elif args.debug: + logging_level = logging.DEBUG + return logging_level + +def main(): + # Parse command line arguments: + args = parse() + + # Setup logging: + logging_level = set_logging_level(args) + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = logging.StreamHandler() + console.setFormatter(formatter) + logger = logging.getLogger('flows') + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging_level) + + # Get missing + for target in api.get_catalog_missing(): + logger.info("Downloading catalog for target=%s...", target) + download_catalog(target) # @TODO: refactor to Tendrils + + # download target catalog for printing + if args.target is not None: + cat = api.get_catalog(args.target) + + print(f"Target:{cat['target'].pprint_all()} " + f"\nReferences: {cat['references'].pprint_all()} " + f"\nAvoid:cat['avoid'].pprint_all()") + if __name__ == '__main__': - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Run catalog.') - parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') - parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') - parser.add_argument('-t', '--target', type=str, help='Target to print catalog for.', nargs='?', default=None) - args = parser.parse_args() - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = logging.StreamHandler() - console.setFormatter(formatter) - logger = logging.getLogger('flows') - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging_level) - - for target in api.get_catalog_missing(): - logger.info("Downloading catalog for target=%s...", target) - download_catalog(target) - - if args.target is not None: - cat = api.get_catalog(args.target) - - print("Target:") - cat['target'].pprint_all() - print("\nReferences:") - cat['references'].pprint_all() - print("\nAvoid:") - cat['avoid'].pprint_all() + main() diff --git a/run_download_ztf.py b/run_download_ztf.py index ea02f55..459a859 100644 --- a/run_download_ztf.py +++ b/run_download_ztf.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Download ZTF photometry from ALERCE API. - https://alerceapi.readthedocs.io/ """ @@ -10,110 +8,111 @@ import logging import os import numpy as np -from flows import ztf, api, load_config -from flows.plots import plt +import matplotlib.pyplot as plt +from tendrils import api +from tendrils.utils import load_config, ztf + -#-------------------------------------------------------------------------------------------------- def main(): - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Download ZTF photometry.') - parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') - parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') - parser.add_argument('-t', '--target', type=str, default=None, help='Target to download ZTF photometry for.') - parser.add_argument('-o', '--output', type=str, default=None, help='Directory to save output to.') - args = parser.parse_args() - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = logging.StreamHandler() - console.setFormatter(formatter) - logger = logging.getLogger(__name__) - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging_level) - - if args.output is None: - config = load_config() - output_dir = config.get('ztf', 'output_photometry', fallback='.') - else: - output_dir = args.output - logger.info("Saving output to '%s'", output_dir) - - # Check that output directory exists: - if not os.path.isdir(output_dir): - parser.error(f"Output directory does not exist: '{output_dir}'") # noqa: G004 - - # Use API to get list of targets to process: - if args.target is None: - targets = api.get_targets() - else: - targets = [api.get_target(args.target)] - - # Colors used for the different filters in plots: - # I know purple is in the wrong end of the scale, but not much I can do - colors = {'gp': 'tab:green', 'rp': 'tab:red', 'ip': 'tab:purple'} - - # Loop through targets: - for tgt in targets: - logger.debug("Target: %s", tgt) - target_name = tgt['target_name'] - - # Paths to the files to be updated: - ztf_lightcurve_path = os.path.join(output_dir, f'{target_name:s}-ztf.ecsv') - ztf_plot_path = os.path.join(output_dir, f'{target_name:s}-ztf.png') - - # If there is no ZTF id, there is no need to try: - # If an old file exists then delete it. - if tgt['ztf_id'] is None: - if os.path.isfile(ztf_lightcurve_path): - os.remove(ztf_lightcurve_path) - if os.path.isfile(ztf_plot_path): - os.remove(ztf_plot_path) - continue - - # Download ZTF photometry as Astropy Table: - tab = ztf.download_ztf_photometry(tgt['targetid']) - logger.debug("ZTF Photometry:\n%s", tab) - if tab is None or len(tab) == 0: - if os.path.isfile(ztf_lightcurve_path): - os.remove(ztf_lightcurve_path) - if os.path.isfile(ztf_plot_path): - os.remove(ztf_plot_path) - continue - - # Write table to file: - tab.write(ztf_lightcurve_path, format='ascii.ecsv', delimiter=',') - - # Find time of maxmimum and 14 days from that: - indx_min = np.argmin(tab['mag']) - maximum_mjd = tab['time'][indx_min] - fortnight_mjd = maximum_mjd + 14 - - # Get LC data out and save as CSV files - fig, ax = plt.subplots() - ax.axvline(maximum_mjd, ls='--', c='k', lw=0.5, label='Maximum') - ax.axvline(fortnight_mjd, ls='--', c='0.5', lw=0.5, label='+14 days') - for fid in np.unique(tab['photfilter']): - col = colors[fid] - band = tab[tab['photfilter'] == fid] - ax.errorbar(band['time'], band['mag'], band['mag_err'], - color=col, ls='-', lw=0.5, marker='.', label=fid) - - ax.invert_yaxis() - ax.set_title(target_name) - ax.set_xlabel('Time (MJD)') - ax.set_ylabel('Magnitude') - ax.legend() - fig.savefig(ztf_plot_path, format='png', bbox_inches='tight') - plt.close(fig) - -#-------------------------------------------------------------------------------------------------- + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Download ZTF photometry.') + parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') + parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') + parser.add_argument('-t', '--target', type=str, default=None, help='Target to download ZTF photometry for.') + parser.add_argument('-o', '--output', type=str, default=None, help='Directory to save output to.') + args = parser.parse_args() + + # Set logging level: + logging_level = logging.INFO + if args.quiet: + logging_level = logging.WARNING + elif args.debug: + logging_level = logging.DEBUG + + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = logging.StreamHandler() + console.setFormatter(formatter) + logger = logging.getLogger(__name__) + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging_level) + + if args.output is None: + config = load_config() + output_dir = config.get('ztf', 'output_photometry', fallback='.') + else: + output_dir = args.output + logger.info("Saving output to '%s'", output_dir) + + # Check that output directory exists: + if not os.path.isdir(output_dir): + parser.error(f"Output directory does not exist: '{output_dir}'") # noqa: G004 + + # Use API to get list of targets to process: + if args.target is None: + targets = api.get_targets() + else: + targets = [api.get_target(args.target)] + + # Colors used for the different filters in plots: + # I know purple is in the wrong end of the scale, but not much I can do + colors = {'gp': 'tab:green', 'rp': 'tab:red', 'ip': 'tab:purple'} + + # Loop through targets: + for tgt in targets: + logger.debug("Target: %s", tgt) + target_name = tgt['target_name'] + + # Paths to the files to be updated: + ztf_lightcurve_path = os.path.join(output_dir, f'{target_name:s}-ztf.ecsv') + ztf_plot_path = os.path.join(output_dir, f'{target_name:s}-ztf.png') + + # If there is no ZTF id, there is no need to try: + # If an old file exists then delete it. + if tgt['ztf_id'] is None: + if os.path.isfile(ztf_lightcurve_path): + os.remove(ztf_lightcurve_path) + if os.path.isfile(ztf_plot_path): + os.remove(ztf_plot_path) + continue + + # Download ZTF photometry as Astropy Table: + tab = ztf.download_ztf_photometry(tgt['targetid']) + logger.debug("ZTF Photometry:\n%s", tab) + if tab is None or len(tab) == 0: + if os.path.isfile(ztf_lightcurve_path): + os.remove(ztf_lightcurve_path) + if os.path.isfile(ztf_plot_path): + os.remove(ztf_plot_path) + continue + + # Write table to file: + tab.write(ztf_lightcurve_path, format='ascii.ecsv', delimiter=',') + + # Find time of maxmimum and 14 days from that: + indx_min = np.argmin(tab['mag']) + maximum_mjd = tab['time'][indx_min] + fortnight_mjd = maximum_mjd + 14 + + # Get LC data out and save as CSV files + fig, ax = plt.subplots() + ax.axvline(maximum_mjd, ls='--', c='k', lw=0.5, label='Maximum') + ax.axvline(fortnight_mjd, ls='--', c='0.5', lw=0.5, label='+14 days') + for fid in np.unique(tab['photfilter']): + col = colors[fid] + band = tab[tab['photfilter'] == fid] + ax.errorbar(band['time'], band['mag'], band['mag_err'], color=col, ls='-', lw=0.5, marker='.', label=fid) + + ax.invert_yaxis() + ax.set_title(target_name) + ax.set_xlabel('Time (MJD)') + ax.set_ylabel('Magnitude') + ax.legend() + fig.savefig(ztf_plot_path, format='png', bbox_inches='tight') + plt.close(fig) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - main() + main() diff --git a/run_ingest.py b/run_ingest.py index 406f9ad..a271b0c 100644 --- a/run_ingest.py +++ b/run_ingest.py @@ -5,7 +5,7 @@ This code is obviously only meant to run on the central Flows systems, and will not work outside of that environment. - +@TODO: Refactor Database out of this script. .. codeauthor:: Rasmus Handberg """ @@ -27,568 +27,564 @@ from flows.load_image import load_image from flows.utilities import get_filehash -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def flows_get_archive_from_path(fname, archives_list=None): - """ - Translate full path into AADC archive identifier and relative path. - - It is highly recommended to provide the list with that call - to this function since it will involve a query to the database - at every call. - """ - - archive = None - relpath = None - - # Get list of archives, if not provided with call: - if archives_list is None: - with AADC_DB() as db: - db.cursor.execute("SELECT archive,path FROM aadc.files_archives;") - archives_list = db.cursor.fetchall() - - # Make sure folder is absolute path - folder = os.path.abspath(fname) - - # Loop through the defined archives and find one that matches: - for opt in archives_list: - archive_path = opt['path'] - if archive_path is not None and archive_path != '': - archive_path = archive_path.rstrip('/\\') + os.path.sep - if folder.startswith(archive_path): - archive = int(opt['archive']) - relpath = folder[len(archive_path):].replace('\\', '/') - break - - # We did not find anything: - if archive is None: - raise RuntimeError("File not in registred archive") - - return archive, relpath - -#-------------------------------------------------------------------------------------------------- + """ + Translate full path into AADC archive identifier and relative path. + + It is highly recommended to provide the list with that call + to this function since it will involve a query to the database + at every call. + """ + + archive = None + relpath = None + + # Get list of archives, if not provided with call: + if archives_list is None: + with AADC_DB() as db: + db.cursor.execute("SELECT archive,path FROM aadc.files_archives;") + archives_list = db.cursor.fetchall() + + # Make sure folder is absolute path + folder = os.path.abspath(fname) + + # Loop through the defined archives and find one that matches: + for opt in archives_list: + archive_path = opt['path'] + if archive_path is not None and archive_path != '': + archive_path = archive_path.rstrip('/\\') + os.path.sep + if folder.startswith(archive_path): + archive = int(opt['archive']) + relpath = folder[len(archive_path):].replace('\\', '/') + break + + # We did not find anything: + if archive is None: + raise RuntimeError("File not in registred archive") + + return archive, relpath + + +# -------------------------------------------------------------------------------------------------- def optipng(fpath): - os.system('optipng -preserve -quiet "%s"' % fpath) + os.system('optipng -preserve -quiet "%s"' % fpath) + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- class CounterFilter(logging.Filter): - """ - A logging filter which counts the number of log records in each level. - """ + """ + A logging filter which counts the number of log records in each level. + """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.counter = defaultdict(int) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.counter = defaultdict(int) - def filter(self, record): # noqa: A003 - self.counter[record.levelname] += 1 - return True + def filter(self, record): # noqa: A003 + self.counter[record.levelname] += 1 + return True -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def create_plot(filepath, target_coord=None, target_position=None): + output_fpath = os.path.abspath(re.sub(r'\.fits(\.gz)?$', '', filepath) + '.png') - output_fpath = os.path.abspath(re.sub(r'\.fits(\.gz)?$', '', filepath) + '.png') + img = load_image(filepath, target_coord=target_coord) - img = load_image(filepath, target_coord=target_coord) + fig = plt.figure(figsize=(12, 12)) + ax = fig.add_subplot(111) + plot_image(img.clean, ax=ax, scale='linear', percentile=[5, 99], cbar='right') + if target_position is not None: + ax.scatter(target_position[0], target_position[1], marker='+', s=20, c='r', label='Target') + fig.savefig(output_fpath, bbox_inches='tight') + plt.close(fig) - fig = plt.figure(figsize=(12,12)) - ax = fig.add_subplot(111) - plot_image(img.clean, ax=ax, scale='linear', percentile=[5, 99], cbar='right') - if target_position is not None: - ax.scatter(target_position[0], target_position[1], marker='+', s=20, c='r', label='Target') - fig.savefig(output_fpath, bbox_inches='tight') - plt.close(fig) + optipng(output_fpath) - optipng(output_fpath) -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def ingest_from_inbox(): - - rootdir_inbox = '/flows/inbox' - rootdir = '/flows/archive' - - logger = logging.getLogger(__name__) - - # Check that root directories are available: - if not os.path.isdir(rootdir_inbox): - raise FileNotFoundError("INBOX does not exists") - if not os.path.isdir(rootdir): - raise FileNotFoundError("ARCHIVE does not exists") - - with AADC_DB() as db: - # Get list of archives: - db.cursor.execute("SELECT archive,path FROM aadc.files_archives;") - archives_list = db.cursor.fetchall() - - # Get list of all available filters: - db.cursor.execute("SELECT photfilter FROM flows.photfilters;") - all_filters = set([row['photfilter'] for row in db.cursor.fetchall()]) - - for inputtype in ('science', 'templates', 'subtracted', 'replace'): # - for fpath in glob.iglob(os.path.join(rootdir_inbox, '*', inputtype, '*')): - logger.info("="*72) - logger.info(fpath) - - # Find the uploadlog corresponding to this file: - db.cursor.execute("SELECT logid FROM flows.uploadlog WHERE uploadpath=%s;", [os.path.relpath(fpath, rootdir_inbox)]) - row = db.cursor.fetchone() - if row is not None: - uploadlogid = row['logid'] - else: - uploadlogid = None - logger.info("Uploadlog ID: %s", uploadlogid) - - # Only accept FITS file, or already compressed FITS files: - if not fpath.endswith('.fits') and not fpath.endswith('.fits.gz'): - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Invalid file type' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - logger.error("Invalid file type: %s", os.path.relpath(fpath, rootdir_inbox)) - continue - - # Get the name of the directory: - # Not pretty, but it works... - target_dirname = fpath[len(rootdir_inbox)+1:] - target_dirname = target_dirname.split(os.path.sep)[0] - - # Convert directory name to target - db.cursor.execute("SELECT targetid,target_name,ra,decl FROM flows.targets WHERE target_name=%s;", [target_dirname]) - row = db.cursor.fetchone() - if row is None: - logger.error('Could not find target: %s', target_dirname) - continue - targetid = row['targetid'] - targetname = row['target_name'] - target_radec = [[row['ra'], row['decl']]] - target_coord = coords.SkyCoord( - ra=row['ra'], - dec=row['decl'], - unit='deg', - frame='icrs') - - if not fpath.endswith('.gz'): - # Gzip the FITS file: - with open(fpath, 'rb') as f_in: - with gzip.open(fpath + '.gz', 'wb') as f_out: - f_out.writelines(f_in) - - # We should now have a Gzip file instead: - if os.path.isfile(fpath) and os.path.isfile(fpath + '.gz'): - # Update the log of this file: - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=%s WHERE logid=%s;", [os.path.relpath(fpath+'.gz', rootdir_inbox), uploadlogid]) - db.conn.commit() - - os.remove(fpath) - fpath += '.gz' - else: - raise RuntimeError("Gzip file was not created correctly") - - version = 1 - if inputtype == 'science': - newpath = os.path.join(rootdir, targetname, os.path.basename(fpath)) - datatype = 1 - elif inputtype == 'templates': - newpath = os.path.join(rootdir, targetname, inputtype, os.path.basename(fpath)) - datatype = 3 - elif inputtype == 'subtracted': - newpath = os.path.join(rootdir, targetname, inputtype, os.path.basename(fpath)) - datatype = 4 - - original_fname = os.path.basename(fpath).replace('diff.fits', '.fits') - db.cursor.execute("SELECT fileid FROM flows.files WHERE targetid=%s AND datatype=1 AND path LIKE %s;", [targetid, '%/' + original_fname]) - subtracted_original_fileid = db.cursor.fetchone() - if subtracted_original_fileid is None: - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='original science image not found' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - logger.error("ORIGINAL SCIENCE IMAGE COULD NOT BE FOUND: %s", os.path.basename(fpath)) - continue - else: - subtracted_original_fileid = subtracted_original_fileid[0] - - elif inputtype == 'replace': - bname = os.path.basename(fpath) - m = re.match(r'^(\d+)_v(\d+)\.fits(\.gz)?$', bname) - if m: - replaceid = int(m.group(1)) - version = int(m.group(2)) - - db.cursor.execute("SELECT datatype,path,version FROM flows.files WHERE fileid=%s;", [replaceid]) - row = db.cursor.fetchone() - if row is None: - logger.error("Unknown fileid to be replaced: %s", bname) - continue - datatype = row['datatype'] - subdir = {1: '', 4: 'subtracted'}[datatype] - - if version != row['version'] + 1: - logger.error("Mismatch in versions: old=%d, new=%d", row['version'], version) - continue - - newfilename = re.sub(r'(_v\d+)?\.fits(\.gz)?$', r'_v{version:d}.fits\2'.format(version=version), os.path.basename(row['path'])) - newpath = os.path.join(rootdir, targetname, subdir, newfilename) - - if datatype == 4: - db.cursor.execute("SELECT associd FROM flows.files_cross_assoc INNER JOIN flows.files ON files.fileid=files_cross_assoc.associd WHERE files_cross_assoc.fileid=%s AND datatype=1;", [replaceid]) - subtracted_original_fileid = db.cursor.fetchone() - if subtracted_original_fileid is None: - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='original science image not found' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - logger.error("ORIGINAL SCIENCE IMAGE COULD NOT BE FOUND: %s", os.path.basename(fpath)) - continue - else: - subtracted_original_fileid = subtracted_original_fileid[0] - else: - logger.error("Invalid replace file name: %s", bname) - continue - - else: - raise RuntimeError("Not understood, Captain") - - logger.info(newpath) - - if os.path.exists(newpath): - logger.error("Already exists") - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Already exists: file name' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - continue - - archive, relpath = flows_get_archive_from_path(newpath, archives_list) - - db.cursor.execute("SELECT fileid FROM flows.files WHERE archive=%s AND path=%s;", [archive, relpath]) - if db.cursor.fetchone() is not None: - logger.error("ALREADY DONE") - continue - - # Calculate filehash of the file being stored: - filehash = get_filehash(fpath) - - # Check that the file does not already exist: - db.cursor.execute("SELECT fileid FROM flows.files WHERE filehash=%s;", [filehash]) - if db.cursor.fetchone() is not None: - logger.error("ALREADY DONE: Filehash") - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Already exists: filehash' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - continue - - # Try to load the image using the same function as the pipeline would: - try: - img = load_image(fpath, target_coord=target_coord) - except Exception as e: # pragma: no cover - logger.exception("Could not load FITS image") - if uploadlogid: - errmsg = str(e) if hasattr(e, 'message') else str(e.message) - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", ['Load Image Error: ' + errmsg, uploadlogid]) - db.conn.commit() - continue - - # Use the WCS in the file to calculate the pixel-positon of the target: - try: - target_pixels = img.wcs.all_world2pix(target_radec, 0).flatten() - except: # noqa: E722, pragma: no cover - logger.exception("Could not find target position using the WCS.") - if uploadlogid: - errmsg = "Could not find target position using the WCS." - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) - db.conn.commit() - continue - - # Check that the position of the target actually falls within - # the pixels of the image: - if target_pixels[0] < -0.5 or target_pixels[1] < -0.5 \ - or target_pixels[0] > img.shape[1]-0.5 or target_pixels[1] > img.shape[0]-0.5: - logger.error("Target position does not fall within image. Check the WCS.") - if uploadlogid: - errmsg = "Target position does not fall within image. Check the WCS." - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) - db.conn.commit() - continue - - # Check that the site was found: - if img.site is None or img.site['siteid'] is None: - logger.error("Unknown SITE") - if uploadlogid: - errmsg = "Unknown site" - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) - db.conn.commit() - continue - - # Check that the extracted photometric filter is valid: - if img.photfilter not in all_filters: - logger.error("Unknown PHOTFILTER: %s", img.photfilter) - if uploadlogid: - errmsg = "Unknown PHOTFILTER: '" + str(img.photfilter) + "'" - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) - db.conn.commit() - continue - - # Do a deep check to ensure that there is not already another file with the same - # properties (target, datatype, site, filter) taken at the same time: - # TODO: Look at the actual overlap with the database, instead of just overlap - # with the central value. This way of doing it is more forgiving. - obstime = img.obstime.utc.mjd - if inputtype != 'replace': - db.cursor.execute("SELECT fileid FROM flows.files WHERE targetid=%s AND datatype=%s AND site=%s AND photfilter=%s AND obstime BETWEEN %s AND %s;", [ - targetid, - datatype, - img.site['siteid'], - img.photfilter, - obstime - 0.5 * img.exptime/86400, - obstime + 0.5 * img.exptime/86400, - ]) - if db.cursor.fetchone() is not None: - logger.error("ALREADY DONE: Deep check") - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Already exists: deep check' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - continue - - try: - # Copy the file to its new home: - os.makedirs(os.path.dirname(newpath), exist_ok=True) - shutil.copy(fpath, newpath) - - # Set file and directory permissions: - # TODO: Can this not be handled in a more elegant way? - os.chmod(os.path.dirname(newpath), 0o2750) - os.chmod(newpath, 0o0440) - - filesize = os.path.getsize(fpath) - - if not fpath.endswith('-e00.fits'): - create_plot(newpath, target_coord=target_coord, target_position=target_pixels) - - db.cursor.execute("INSERT INTO flows.files (archive,path,targetid,datatype,site,filesize,filehash,obstime,photfilter,exptime,version,available) VALUES (%(archive)s,%(relpath)s,%(targetid)s,%(datatype)s,%(site)s,%(filesize)s,%(filehash)s,%(obstime)s,%(photfilter)s,%(exptime)s,%(version)s,1) RETURNING fileid;", { - 'archive': archive, - 'relpath': relpath, - 'targetid': targetid, - 'datatype': datatype, - 'site': img.site['siteid'], - 'filesize': filesize, - 'filehash': filehash, - 'obstime': obstime, - 'photfilter': img.photfilter, - 'exptime': img.exptime, - 'version': version - }) - fileid = db.cursor.fetchone()[0] - - if datatype == 4: - db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", [fileid, subtracted_original_fileid]) - - if inputtype == 'replace': - db.cursor.execute("UPDATE flows.files SET newest_version=FALSE WHERE fileid=%s;", [replaceid]) - - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET fileid=%s,status='ok' WHERE logid=%s;", [fileid, uploadlogid]) - - db.conn.commit() - except: # noqa: E722, pragma: no cover - db.conn.rollback() - if os.path.exists(newpath): - os.remove(newpath) - raise - else: - logger.info("DELETE THE ORIGINAL FILE") - if os.path.isfile(newpath): - os.remove(fpath) - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - -#-------------------------------------------------------------------------------------------------- + rootdir_inbox = '/flows/inbox' + rootdir = '/flows/archive' + + logger = logging.getLogger(__name__) + + # Check that root directories are available: + if not os.path.isdir(rootdir_inbox): + raise FileNotFoundError("INBOX does not exists") + if not os.path.isdir(rootdir): + raise FileNotFoundError("ARCHIVE does not exists") + + with AADC_DB() as db: + # Get list of archives: + db.cursor.execute("SELECT archive,path FROM aadc.files_archives;") + archives_list = db.cursor.fetchall() + + # Get list of all available filters: + db.cursor.execute("SELECT photfilter FROM flows.photfilters;") + all_filters = set([row['photfilter'] for row in db.cursor.fetchall()]) + + for inputtype in ('science', 'templates', 'subtracted', 'replace'): # + for fpath in glob.iglob(os.path.join(rootdir_inbox, '*', inputtype, '*')): + logger.info("=" * 72) + logger.info(fpath) + + # Find the uploadlog corresponding to this file: + db.cursor.execute("SELECT logid FROM flows.uploadlog WHERE uploadpath=%s;", + [os.path.relpath(fpath, rootdir_inbox)]) + row = db.cursor.fetchone() + if row is not None: + uploadlogid = row['logid'] + else: + uploadlogid = None + logger.info("Uploadlog ID: %s", uploadlogid) + + # Only accept FITS file, or already compressed FITS files: + if not fpath.endswith('.fits') and not fpath.endswith('.fits.gz'): + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET status='Invalid file type' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + logger.error("Invalid file type: %s", os.path.relpath(fpath, rootdir_inbox)) + continue + + # Get the name of the directory: + # Not pretty, but it works... + target_dirname = fpath[len(rootdir_inbox) + 1:] + target_dirname = target_dirname.split(os.path.sep)[0] + + # Convert directory name to target + db.cursor.execute("SELECT targetid,target_name,ra,decl FROM flows.targets WHERE target_name=%s;", + [target_dirname]) + row = db.cursor.fetchone() + if row is None: + logger.error('Could not find target: %s', target_dirname) + continue + targetid = row['targetid'] + targetname = row['target_name'] + target_radec = [[row['ra'], row['decl']]] + target_coord = coords.SkyCoord(ra=row['ra'], dec=row['decl'], unit='deg', frame='icrs') + + if not fpath.endswith('.gz'): + # Gzip the FITS file: + with open(fpath, 'rb') as f_in: + with gzip.open(fpath + '.gz', 'wb') as f_out: + f_out.writelines(f_in) + + # We should now have a Gzip file instead: + if os.path.isfile(fpath) and os.path.isfile(fpath + '.gz'): + # Update the log of this file: + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=%s WHERE logid=%s;", + [os.path.relpath(fpath + '.gz', rootdir_inbox), uploadlogid]) + db.conn.commit() + + os.remove(fpath) + fpath += '.gz' + else: + raise RuntimeError("Gzip file was not created correctly") + + version = 1 + if inputtype == 'science': + newpath = os.path.join(rootdir, targetname, os.path.basename(fpath)) + datatype = 1 + elif inputtype == 'templates': + newpath = os.path.join(rootdir, targetname, inputtype, os.path.basename(fpath)) + datatype = 3 + elif inputtype == 'subtracted': + newpath = os.path.join(rootdir, targetname, inputtype, os.path.basename(fpath)) + datatype = 4 + + original_fname = os.path.basename(fpath).replace('diff.fits', '.fits') + db.cursor.execute( + "SELECT fileid FROM flows.files WHERE targetid=%s AND datatype=1 AND path LIKE %s;", + [targetid, '%/' + original_fname]) + subtracted_original_fileid = db.cursor.fetchone() + if subtracted_original_fileid is None: + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='original science image not found' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + logger.error("ORIGINAL SCIENCE IMAGE COULD NOT BE FOUND: %s", os.path.basename(fpath)) + continue + else: + subtracted_original_fileid = subtracted_original_fileid[0] + + elif inputtype == 'replace': + bname = os.path.basename(fpath) + m = re.match(r'^(\d+)_v(\d+)\.fits(\.gz)?$', bname) + if m: + replaceid = int(m.group(1)) + version = int(m.group(2)) + + db.cursor.execute("SELECT datatype,path,version FROM flows.files WHERE fileid=%s;", [replaceid]) + row = db.cursor.fetchone() + if row is None: + logger.error("Unknown fileid to be replaced: %s", bname) + continue + datatype = row['datatype'] + subdir = {1: '', 4: 'subtracted'}[datatype] + + if version != row['version'] + 1: + logger.error("Mismatch in versions: old=%d, new=%d", row['version'], version) + continue + + newfilename = re.sub(r'(_v\d+)?\.fits(\.gz)?$', r'_v{version:d}.fits\2'.format(version=version), + os.path.basename(row['path'])) + newpath = os.path.join(rootdir, targetname, subdir, newfilename) + + if datatype == 4: + db.cursor.execute( + "SELECT associd FROM flows.files_cross_assoc INNER JOIN flows.files ON files.fileid=files_cross_assoc.associd WHERE files_cross_assoc.fileid=%s AND datatype=1;", + [replaceid]) + subtracted_original_fileid = db.cursor.fetchone() + if subtracted_original_fileid is None: + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='original science image not found' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + logger.error("ORIGINAL SCIENCE IMAGE COULD NOT BE FOUND: %s", os.path.basename(fpath)) + continue + else: + subtracted_original_fileid = subtracted_original_fileid[0] + else: + logger.error("Invalid replace file name: %s", bname) + continue + + else: + raise RuntimeError("Not understood, Captain") + + logger.info(newpath) + + if os.path.exists(newpath): + logger.error("Already exists") + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='Already exists: file name' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + continue + + archive, relpath = flows_get_archive_from_path(newpath, archives_list) + + db.cursor.execute("SELECT fileid FROM flows.files WHERE archive=%s AND path=%s;", [archive, relpath]) + if db.cursor.fetchone() is not None: + logger.error("ALREADY DONE") + continue + + # Calculate filehash of the file being stored: + filehash = get_filehash(fpath) + + # Check that the file does not already exist: + db.cursor.execute("SELECT fileid FROM flows.files WHERE filehash=%s;", [filehash]) + if db.cursor.fetchone() is not None: + logger.error("ALREADY DONE: Filehash") + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='Already exists: filehash' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + continue + + # Try to load the image using the same function as the pipeline would: + try: + img = load_image(fpath, target_coord=target_coord) + except Exception as e: # pragma: no cover + logger.exception("Could not load FITS image") + if uploadlogid: + errmsg = str(e) if hasattr(e, 'message') else str(e.message) + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", + ['Load Image Error: ' + errmsg, uploadlogid]) + db.conn.commit() + continue + + # Use the WCS in the file to calculate the pixel-positon of the target: + try: + target_pixels = img.wcs.all_world2pix(target_radec, 0).flatten() + except: # noqa: E722, pragma: no cover + logger.exception("Could not find target position using the WCS.") + if uploadlogid: + errmsg = "Could not find target position using the WCS." + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) + db.conn.commit() + continue + + # Check that the position of the target actually falls within + # the pixels of the image: + if target_pixels[0] < -0.5 or target_pixels[1] < -0.5 or target_pixels[0] > img.shape[1] - 0.5 or \ + target_pixels[1] > img.shape[0] - 0.5: + logger.error("Target position does not fall within image. Check the WCS.") + if uploadlogid: + errmsg = "Target position does not fall within image. Check the WCS." + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) + db.conn.commit() + continue + + # Check that the site was found: + if img.site is None or img.site['siteid'] is None: + logger.error("Unknown SITE") + if uploadlogid: + errmsg = "Unknown site" + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) + db.conn.commit() + continue + + # Check that the extracted photometric filter is valid: + if img.photfilter not in all_filters: + logger.error("Unknown PHOTFILTER: %s", img.photfilter) + if uploadlogid: + errmsg = "Unknown PHOTFILTER: '" + str(img.photfilter) + "'" + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) + db.conn.commit() + continue + + # Do a deep check to ensure that there is not already another file with the same + # properties (target, datatype, site, filter) taken at the same time: + # TODO: Look at the actual overlap with the database, instead of just overlap + # with the central value. This way of doing it is more forgiving. + obstime = img.obstime.utc.mjd + if inputtype != 'replace': + db.cursor.execute( + "SELECT fileid FROM flows.files WHERE targetid=%s AND datatype=%s AND site=%s AND photfilter=%s AND obstime BETWEEN %s AND %s;", + [targetid, datatype, img.site['siteid'], img.photfilter, obstime - 0.5 * img.exptime / 86400, + obstime + 0.5 * img.exptime / 86400, ]) + if db.cursor.fetchone() is not None: + logger.error("ALREADY DONE: Deep check") + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='Already exists: deep check' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + continue + + try: + # Copy the file to its new home: + os.makedirs(os.path.dirname(newpath), exist_ok=True) + shutil.copy(fpath, newpath) + + # Set file and directory permissions: + # TODO: Can this not be handled in a more elegant way? + os.chmod(os.path.dirname(newpath), 0o2750) + os.chmod(newpath, 0o0440) + + filesize = os.path.getsize(fpath) + + if not fpath.endswith('-e00.fits'): + create_plot(newpath, target_coord=target_coord, target_position=target_pixels) + + db.cursor.execute( + "INSERT INTO flows.files (archive,path,targetid,datatype,site,filesize,filehash,obstime,photfilter,exptime,version,available) VALUES (%(archive)s,%(relpath)s,%(targetid)s,%(datatype)s,%(site)s,%(filesize)s,%(filehash)s,%(obstime)s,%(photfilter)s,%(exptime)s,%(version)s,1) RETURNING fileid;", + {'archive': archive, 'relpath': relpath, 'targetid': targetid, 'datatype': datatype, + 'site': img.site['siteid'], 'filesize': filesize, 'filehash': filehash, 'obstime': obstime, + 'photfilter': img.photfilter, 'exptime': img.exptime, 'version': version}) + fileid = db.cursor.fetchone()[0] + + if datatype == 4: + db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", + [fileid, subtracted_original_fileid]) + + if inputtype == 'replace': + db.cursor.execute("UPDATE flows.files SET newest_version=FALSE WHERE fileid=%s;", [replaceid]) + + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET fileid=%s,status='ok' WHERE logid=%s;", + [fileid, uploadlogid]) + + db.conn.commit() + except: # noqa: E722, pragma: no cover + db.conn.rollback() + if os.path.exists(newpath): + os.remove(newpath) + raise + else: + logger.info("DELETE THE ORIGINAL FILE") + if os.path.isfile(newpath): + os.remove(fpath) + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL WHERE logid=%s;", [uploadlogid]) + db.conn.commit() + + +# -------------------------------------------------------------------------------------------------- def ingest_photometry_from_inbox(): - - rootdir_inbox = '/flows/inbox' - rootdir_archive = '/flows/archive_photometry' - - logger = logging.getLogger(__name__) - - # Check that root directories are available: - if not os.path.isdir(rootdir_inbox): - raise FileNotFoundError("INBOX does not exists") - if not os.path.isdir(rootdir_archive): - raise FileNotFoundError("ARCHIVE does not exists") - - with AADC_DB() as db: - # Get list of archives: - db.cursor.execute("SELECT archive,path FROM aadc.files_archives;") - archives_list = db.cursor.fetchall() - - for fpath in glob.iglob(os.path.join(rootdir_inbox, '*', 'photometry', '*')): - logger.info("="*72) - logger.info(fpath) - - # Find the uploadlog corresponding to this file: - db.cursor.execute("SELECT logid FROM flows.uploadlog WHERE uploadpath=%s;", [os.path.relpath(fpath, rootdir_inbox)]) - row = db.cursor.fetchone() - if row is not None: - uploadlogid = row['logid'] - else: - uploadlogid = None - - # Only accept FITS file, or already compressed FITS files: - if not fpath.endswith('.zip'): - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Invalid file type' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - logger.error("Invalid file type: %s", os.path.relpath(fpath, rootdir_inbox)) - continue - - # Get the name of the directory: - # Not pretty, but it works... - target_dirname = fpath[len(rootdir_inbox)+1:] - target_dirname = target_dirname.split(os.path.sep)[0] - - # Convert directory name to target - db.cursor.execute("SELECT targetid,target_name FROM flows.targets WHERE target_name=%s;", [target_dirname]) - row = db.cursor.fetchone() - if row is None: - logger.error('Could not find target: %s', target_dirname) - continue - targetid = row['targetid'] - targetname = row['target_name'] - - newpath = None - try: - with tempfile.TemporaryDirectory() as tmpdir: - # - tmpphotfile = os.path.join(tmpdir, 'photometry.ecsv') - - # Extract the ZIP file: - with ZipFile(fpath, mode='r') as myzip: - for member in myzip.infolist(): - # Remove any directory structure from the zip file: - if member.filename.endswith('/'): # member.is_dir() - continue - member.filename = os.path.basename(member.filename) - - # Due to security considerations, we only allow specific files - # to be extracted: - if member.filename == 'photometry.ecsv': - myzip.extract(member, path=tmpdir) - elif member.filename.endswith('.png') or member.filename.endswith('.log'): - myzip.extract(member, path=tmpdir) - - # Check that the photometry ECSV file at least exists: - if not os.path.isfile(tmpphotfile): - raise FileNotFoundError("Photometry is not found") - - # Load photometry table: - tab = Table.read(tmpphotfile, format='ascii.ecsv') - fileid_img = int(tab.meta['fileid']) - targetid_table = int(tab.meta['targetid']) - - assert targetid_table == targetid - - # Find out which version number to assign to file: - db.cursor.execute("SELECT MAX(files.version) AS latest_version FROM flows.files_cross_assoc fca INNER JOIN flows.files ON fca.fileid=files.fileid WHERE fca.associd=%s AND files.datatype=2;", [fileid_img,]) - latest_version = db.cursor.fetchone() - if latest_version[0] is None: - new_version = 1 - else: - new_version = latest_version[0] + 1 - - # Create a new path and filename that is slightly more descriptive: - newpath = os.path.join( - rootdir_archive, - targetname, - f'{fileid_img:05d}', - f'v{new_version:02d}', - f'photometry-{targetname:s}-{fileid_img:05d}-v{new_version:02d}.ecsv' - ) - logger.info(newpath) - - if os.path.exists(newpath): - logger.error("Already exists") - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Already exists: file name' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - continue - - archive, relpath = flows_get_archive_from_path(newpath, archives_list) - - db.cursor.execute("SELECT fileid FROM flows.files WHERE archive=%s AND path=%s;", [archive, relpath]) - if db.cursor.fetchone() is not None: - logger.error("ALREADY DONE") - continue - - db.cursor.execute("SELECT * FROM flows.files WHERE fileid=%s;", [fileid_img]) - row = db.cursor.fetchone() - site = row['site'] - - assert targetid == row['targetid'] - assert tab.meta['photfilter'] == row['photfilter'] - - # Optimize all the PNG files in the temp directory: - for f in glob.iglob(os.path.join(tmpdir, '*.png')): - optipng(f) - - # Copy the full directory to its new home: - shutil.copytree(tmpdir, os.path.dirname(newpath)) - os.rename(os.path.join(os.path.dirname(newpath), 'photometry.ecsv'), newpath) - - # Get information about file: - filesize = os.path.getsize(newpath) - filehash = get_filehash(newpath) - - db.cursor.execute("INSERT INTO flows.files (archive,path,targetid,datatype,site,filesize,filehash,obstime,photfilter,version,available) VALUES (%(archive)s,%(relpath)s,%(targetid)s,%(datatype)s,%(site)s,%(filesize)s,%(filehash)s,%(obstime)s,%(photfilter)s,%(version)s,1) RETURNING fileid;", { - 'archive': archive, - 'relpath': relpath, - 'targetid': targetid, - 'datatype': 2, - 'site': site, - 'filesize': filesize, - 'filehash': filehash, - 'obstime': tab.meta['obstime-bmjd'], - 'photfilter': tab.meta['photfilter'], - 'version': new_version - }) - fileid = db.cursor.fetchone()[0] - - # Add dependencies: - db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", [fileid, fileid_img]) - if tab.meta['template'] is not None: - db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", [fileid, tab.meta['template']]) - if tab.meta['diffimg'] is not None: - db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", [fileid, tab.meta['diffimg']]) - - indx_raw = (tab['starid'] == 0) - indx_sub = (tab['starid'] == -1) - indx_ref = (tab['starid'] > 0) - - frd = float(np.nanmax(tab[indx_ref]['mag'])) - if not np.isfinite(frd): - frd = None - - phot_summary = { - 'fileid_img': fileid_img, - 'fileid_phot': fileid, - 'fileid_template': tab.meta['template'], - 'fileid_diffimg': tab.meta['diffimg'], - 'targetid': targetid, - 'obstime': tab.meta['obstime-bmjd'], - 'photfilter': tab.meta['photfilter'], - 'mag_raw': float(tab[indx_raw]['mag']), - 'mag_raw_error': float(tab[indx_raw]['mag_error']), - 'mag_sub': None if not any(indx_sub) else float(tab[indx_sub]['mag']), - 'mag_sub_error': None if not any(indx_sub) else float(tab[indx_sub]['mag_error']), - 'zeropoint': float(tab.meta['zp']), - 'zeropoint_error': float(tab.meta['zp_error']), - 'zeropoint_diff': float(tab.meta['zp_diff']), - 'fwhm': float(tab.meta['fwhm'].value), - 'seeing': float(tab.meta['seeing'].value), - 'references_detected': int(np.sum(indx_ref)), - 'used_for_epsf': int(np.sum(tab['used_for_epsf'])), - 'faintest_reference_detected': frd, - 'pipeline_version': tab.meta['version'], - 'latest_version': new_version - } - - db.cursor.execute("""INSERT INTO flows.photometry_details ( + rootdir_inbox = '/flows/inbox' + rootdir_archive = '/flows/archive_photometry' + + logger = logging.getLogger(__name__) + + # Check that root directories are available: + if not os.path.isdir(rootdir_inbox): + raise FileNotFoundError("INBOX does not exists") + if not os.path.isdir(rootdir_archive): + raise FileNotFoundError("ARCHIVE does not exists") + + with AADC_DB() as db: + # Get list of archives: + db.cursor.execute("SELECT archive,path FROM aadc.files_archives;") + archives_list = db.cursor.fetchall() + + for fpath in glob.iglob(os.path.join(rootdir_inbox, '*', 'photometry', '*')): + logger.info("=" * 72) + logger.info(fpath) + + # Find the uploadlog corresponding to this file: + db.cursor.execute("SELECT logid FROM flows.uploadlog WHERE uploadpath=%s;", + [os.path.relpath(fpath, rootdir_inbox)]) + row = db.cursor.fetchone() + if row is not None: + uploadlogid = row['logid'] + else: + uploadlogid = None + + # Only accept FITS file, or already compressed FITS files: + if not fpath.endswith('.zip'): + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET status='Invalid file type' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + logger.error("Invalid file type: %s", os.path.relpath(fpath, rootdir_inbox)) + continue + + # Get the name of the directory: + # Not pretty, but it works... + target_dirname = fpath[len(rootdir_inbox) + 1:] + target_dirname = target_dirname.split(os.path.sep)[0] + + # Convert directory name to target + db.cursor.execute("SELECT targetid,target_name FROM flows.targets WHERE target_name=%s;", [target_dirname]) + row = db.cursor.fetchone() + if row is None: + logger.error('Could not find target: %s', target_dirname) + continue + targetid = row['targetid'] + targetname = row['target_name'] + + newpath = None + try: + with tempfile.TemporaryDirectory() as tmpdir: + # + tmpphotfile = os.path.join(tmpdir, 'photometry.ecsv') + + # Extract the ZIP file: + with ZipFile(fpath, mode='r') as myzip: + for member in myzip.infolist(): + # Remove any directory structure from the zip file: + if member.filename.endswith('/'): # member.is_dir() + continue + member.filename = os.path.basename(member.filename) + + # Due to security considerations, we only allow specific files + # to be extracted: + if member.filename == 'photometry.ecsv': + myzip.extract(member, path=tmpdir) + elif member.filename.endswith('.png') or member.filename.endswith('.log'): + myzip.extract(member, path=tmpdir) + + # Check that the photometry ECSV file at least exists: + if not os.path.isfile(tmpphotfile): + raise FileNotFoundError("Photometry is not found") + + # Load photometry table: + tab = Table.read(tmpphotfile, format='ascii.ecsv') + fileid_img = int(tab.meta['fileid']) + targetid_table = int(tab.meta['targetid']) + + assert targetid_table == targetid + + # Find out which version number to assign to file: + db.cursor.execute( + "SELECT MAX(files.version) AS latest_version FROM flows.files_cross_assoc fca INNER JOIN flows.files ON fca.fileid=files.fileid WHERE fca.associd=%s AND files.datatype=2;", + [fileid_img, ]) + latest_version = db.cursor.fetchone() + if latest_version[0] is None: + new_version = 1 + else: + new_version = latest_version[0] + 1 + + # Create a new path and filename that is slightly more descriptive: + newpath = os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}', f'v{new_version:02d}', + f'photometry-{targetname:s}-{fileid_img:05d}-v{new_version:02d}.ecsv') + logger.info(newpath) + + if os.path.exists(newpath): + logger.error("Already exists") + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='Already exists: file name' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + continue + + archive, relpath = flows_get_archive_from_path(newpath, archives_list) + + db.cursor.execute("SELECT fileid FROM flows.files WHERE archive=%s AND path=%s;", + [archive, relpath]) + if db.cursor.fetchone() is not None: + logger.error("ALREADY DONE") + continue + + db.cursor.execute("SELECT * FROM flows.files WHERE fileid=%s;", [fileid_img]) + row = db.cursor.fetchone() + site = row['site'] + + assert targetid == row['targetid'] + assert tab.meta['photfilter'] == row['photfilter'] + + # Optimize all the PNG files in the temp directory: + for f in glob.iglob(os.path.join(tmpdir, '*.png')): + optipng(f) + + # Copy the full directory to its new home: + shutil.copytree(tmpdir, os.path.dirname(newpath)) + os.rename(os.path.join(os.path.dirname(newpath), 'photometry.ecsv'), newpath) + + # Get information about file: + filesize = os.path.getsize(newpath) + filehash = get_filehash(newpath) + + db.cursor.execute( + "INSERT INTO flows.files (archive,path,targetid,datatype,site,filesize,filehash,obstime,photfilter,version,available) VALUES (%(archive)s,%(relpath)s,%(targetid)s,%(datatype)s,%(site)s,%(filesize)s,%(filehash)s,%(obstime)s,%(photfilter)s,%(version)s,1) RETURNING fileid;", + {'archive': archive, 'relpath': relpath, 'targetid': targetid, 'datatype': 2, 'site': site, + 'filesize': filesize, 'filehash': filehash, 'obstime': tab.meta['obstime-bmjd'], + 'photfilter': tab.meta['photfilter'], 'version': new_version}) + fileid = db.cursor.fetchone()[0] + + # Add dependencies: + db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", + [fileid, fileid_img]) + if tab.meta['template'] is not None: + db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", + [fileid, tab.meta['template']]) + if tab.meta['diffimg'] is not None: + db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", + [fileid, tab.meta['diffimg']]) + + indx_raw = (tab['starid'] == 0) + indx_sub = (tab['starid'] == -1) + indx_ref = (tab['starid'] > 0) + + frd = float(np.nanmax(tab[indx_ref]['mag'])) + if not np.isfinite(frd): + frd = None + + phot_summary = {'fileid_img': fileid_img, 'fileid_phot': fileid, + 'fileid_template': tab.meta['template'], 'fileid_diffimg': tab.meta['diffimg'], + 'targetid': targetid, 'obstime': tab.meta['obstime-bmjd'], + 'photfilter': tab.meta['photfilter'], 'mag_raw': float(tab[indx_raw]['mag']), + 'mag_raw_error': float(tab[indx_raw]['mag_error']), + 'mag_sub': None if not any(indx_sub) else float(tab[indx_sub]['mag']), + 'mag_sub_error': None if not any(indx_sub) else float(tab[indx_sub]['mag_error']), + 'zeropoint': float(tab.meta['zp']), 'zeropoint_error': float(tab.meta['zp_error']), + 'zeropoint_diff': float(tab.meta['zp_diff']), 'fwhm': float(tab.meta['fwhm'].value), + 'seeing': float(tab.meta['seeing'].value), 'references_detected': int(np.sum(indx_ref)), + 'used_for_epsf': int(np.sum(tab['used_for_epsf'])), 'faintest_reference_detected': frd, + 'pipeline_version': tab.meta['version'], 'latest_version': new_version} + + db.cursor.execute("""INSERT INTO flows.photometry_details ( fileid_phot, fileid_img, fileid_template, @@ -628,9 +624,9 @@ def ingest_photometry_from_inbox(): %(pipeline_version)s );""", phot_summary) - db.cursor.execute("SELECT * FROM flows.photometry_summary WHERE fileid_img=%s;", [fileid_img]) - if db.cursor.fetchone() is None: - db.cursor.execute("""INSERT INTO flows.photometry_summary ( + db.cursor.execute("SELECT * FROM flows.photometry_summary WHERE fileid_img=%s;", [fileid_img]) + if db.cursor.fetchone() is None: + db.cursor.execute("""INSERT INTO flows.photometry_summary ( fileid_phot, fileid_img, fileid_template, @@ -659,8 +655,8 @@ def ingest_photometry_from_inbox(): %(pipeline_version)s, %(latest_version)s );""", phot_summary) - else: - db.cursor.execute("""UPDATE flows.photometry_summary SET + else: + db.cursor.execute("""UPDATE flows.photometry_summary SET fileid_phot=%(fileid_phot)s, targetid=%(targetid)s, fileid_template=%(fileid_template)s, @@ -675,93 +671,99 @@ def ingest_photometry_from_inbox(): latest_version=%(latest_version)s WHERE fileid_img=%(fileid_img)s;""", phot_summary) - # Update the photometry status to done: - db.cursor.execute("UPDATE flows.photometry_status SET status='done' WHERE fileid=%(fileid_img)s AND status='ingest';", phot_summary) - - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET fileid=%s,status='ok' WHERE logid=%s;", [fileid, uploadlogid]) - - db.conn.commit() - - except: # noqa: E722, pragma: no cover - db.conn.rollback() - if newpath is not None and os.path.isdir(os.path.dirname(newpath)): - shutil.rmtree(os.path.dirname(newpath)) - raise - else: - # Set file and directory permissions: - # TODO: Can this not be handled in a more elegant way? - os.chmod(os.path.join(rootdir_archive, targetname), 0o2750) - os.chmod(os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}'), 0o2750) - os.chmod(os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}', f'v{new_version:02d}'), 0o2550) - for f in os.listdir(os.path.dirname(newpath)): - os.chmod(os.path.join(os.path.dirname(newpath), f), 0o0440) - - logger.info("DELETE THE ORIGINAL FILE") - if os.path.isfile(fpath): - os.remove(fpath) - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - -#-------------------------------------------------------------------------------------------------- + # Update the photometry status to done: + db.cursor.execute( + "UPDATE flows.photometry_status SET status='done' WHERE fileid=%(fileid_img)s AND status='ingest';", + phot_summary) + + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET fileid=%s,status='ok' WHERE logid=%s;", + [fileid, uploadlogid]) + + db.conn.commit() + + except: # noqa: E722, pragma: no cover + db.conn.rollback() + if newpath is not None and os.path.isdir(os.path.dirname(newpath)): + shutil.rmtree(os.path.dirname(newpath)) + raise + else: + # Set file and directory permissions: + # TODO: Can this not be handled in a more elegant way? + os.chmod(os.path.join(rootdir_archive, targetname), 0o2750) + os.chmod(os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}'), 0o2750) + os.chmod(os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}', f'v{new_version:02d}'), 0o2550) + for f in os.listdir(os.path.dirname(newpath)): + os.chmod(os.path.join(os.path.dirname(newpath), f), 0o0440) + + logger.info("DELETE THE ORIGINAL FILE") + if os.path.isfile(fpath): + os.remove(fpath) + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL WHERE logid=%s;", [uploadlogid]) + db.conn.commit() + + +# -------------------------------------------------------------------------------------------------- def cleanup_inbox(): - """ - Cleanup of inbox directory - """ - rootdir_inbox = '/flows/inbox' - - # Just a simple check to begin with: - if not os.path.isdir(rootdir_inbox): - raise FileNotFoundError("INBOX could not be found.") - - # Remove empty directories: - for inputtype in ('science', 'templates', 'subtracted', 'photometry', 'replace'): - for dpath in glob.iglob(os.path.join(rootdir_inbox, '*', inputtype)): - if not os.listdir(dpath): - os.rmdir(dpath) - - for dpath in glob.iglob(os.path.join(rootdir_inbox, '*')): - if os.path.isdir(dpath) and not os.listdir(dpath): - os.rmdir(dpath) - - # Delete left-over files in the database tables, that have been removed from disk: - with AADC_DB() as db: - db.cursor.execute("SELECT logid,uploadpath FROM flows.uploadlog WHERE uploadpath IS NOT NULL;") - for row in db.cursor.fetchall(): - if not os.path.isfile(os.path.join(rootdir_inbox, row['uploadpath'])): - print("MARK AS DELETED IN DATABASE: " + row['uploadpath']) - db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL,status='File deleted' WHERE logid=%s;", [row['logid']]) - db.conn.commit() - -#-------------------------------------------------------------------------------------------------- + """ + Cleanup of inbox directory + """ + rootdir_inbox = '/flows/inbox' + + # Just a simple check to begin with: + if not os.path.isdir(rootdir_inbox): + raise FileNotFoundError("INBOX could not be found.") + + # Remove empty directories: + for inputtype in ('science', 'templates', 'subtracted', 'photometry', 'replace'): + for dpath in glob.iglob(os.path.join(rootdir_inbox, '*', inputtype)): + if not os.listdir(dpath): + os.rmdir(dpath) + + for dpath in glob.iglob(os.path.join(rootdir_inbox, '*')): + if os.path.isdir(dpath) and not os.listdir(dpath): + os.rmdir(dpath) + + # Delete left-over files in the database tables, that have been removed from disk: + with AADC_DB() as db: + db.cursor.execute("SELECT logid,uploadpath FROM flows.uploadlog WHERE uploadpath IS NOT NULL;") + for row in db.cursor.fetchall(): + if not os.path.isfile(os.path.join(rootdir_inbox, row['uploadpath'])): + print("MARK AS DELETED IN DATABASE: " + row['uploadpath']) + db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL,status='File deleted' WHERE logid=%s;", + [row['logid']]) + db.conn.commit() + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - logging_level = logging.INFO - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = logging.StreamHandler(sys.stdout) - console.setFormatter(formatter) - logger = logging.getLogger(__name__) - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging_level) - - # Add a CounterFilter to the logger, which will count the number of log-records - # being passed through the logger. Can be used to count the number of errors/warnings: - counter = CounterFilter() - logger.addFilter(counter) - - # Run the ingests and cleanup: - ingest_from_inbox() - ingest_photometry_from_inbox() - cleanup_inbox() - - # Check the number of errors or warnings issued, and convert these to a return-code: - logcounts = counter.counter - if logcounts.get('ERROR', 0) > 0 or logcounts.get('CRITICAL', 0) > 0: - sys.exit(4) - elif logcounts.get('WARNING', 0) > 0: - sys.exit(3) - sys.exit(0) + logging_level = logging.INFO + + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = logging.StreamHandler(sys.stdout) + console.setFormatter(formatter) + logger = logging.getLogger(__name__) + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging_level) + + # Add a CounterFilter to the logger, which will count the number of log-records + # being passed through the logger. Can be used to count the number of errors/warnings: + counter = CounterFilter() + logger.addFilter(counter) + + # Run the ingests and cleanup: + ingest_from_inbox() + ingest_photometry_from_inbox() + cleanup_inbox() + + # Check the number of errors or warnings issued, and convert these to a return-code: + logcounts = counter.counter + if logcounts.get('ERROR', 0) > 0 or logcounts.get('CRITICAL', 0) > 0: + sys.exit(4) + elif logcounts.get('WARNING', 0) > 0: + sys.exit(3) + sys.exit(0) diff --git a/run_photometry.py b/run_photometry.py index 4614a6b..f7042ae 100644 --- a/run_photometry.py +++ b/run_photometry.py @@ -1,9 +1,5 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ -Run Flows photometry. - -.. codeauthor:: Rasmus Handberg +Run Flows photometry. Allows multithreaded operations to be run """ import argparse @@ -13,175 +9,171 @@ import shutil import functools import multiprocessing -from flows import api, photometry, load_config - -# -------------------------------------------------------------------------------------------------- -def process_fileid(fid, output_folder_root=None, attempt_imagematch=True, autoupload=False, - keep_diff_fixed=False, cm_timeout=None): - logger = logging.getLogger('flows') - logging.captureWarnings(True) - logger_warn = logging.getLogger('py.warnings') - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S") - - datafile = api.get_datafile(fid) - target_name = datafile['target_name'] - - # Folder to save output: - output_folder = os.path.join(output_folder_root, target_name, f'{fid:05d}') - - photfile = None - _filehandler = None - try: - # Set the status to indicate that we have started processing: - if autoupload: - api.set_photometry_status(fid, 'running') - - # Create the output directory if it doesn't exist: - os.makedirs(output_folder, exist_ok=True) - - # Also write any logging output to the - _filehandler = logging.FileHandler(os.path.join(output_folder, 'photometry.log'), mode='w') - _filehandler.setFormatter(formatter) - _filehandler.setLevel(logging.INFO) - logger.addHandler(_filehandler) - logger_warn.addHandler(_filehandler) - - photfile = photometry( - fileid=fid, - output_folder=output_folder, - attempt_imagematch=attempt_imagematch, - keep_diff_fixed=keep_diff_fixed, - cm_timeout=cm_timeout) - - except (SystemExit, KeyboardInterrupt): - logger.error("Aborted by user or system.") - if os.path.exists(output_folder): - shutil.rmtree(output_folder, ignore_errors=True) - photfile = None - if autoupload: - api.set_photometry_status(fid, 'abort') - - except: # noqa: E722, pragma: no cover - logger.exception("Photometry failed") - photfile = None - if autoupload: - api.set_photometry_status(fid, 'error') - - if _filehandler is not None: - logger.removeHandler(_filehandler) - logger_warn.removeHandler(_filehandler) - - if photfile is not None: - if autoupload: - api.upload_photometry(fid, delete_completed=True) - api.set_photometry_status(fid, 'ingest') - - return photfile - -# -------------------------------------------------------------------------------------------------- +from tendrils import api, utils +from flows import photometry + + +def process_fileid(fid, output_folder_root=None, attempt_imagematch=True, autoupload=False, keep_diff_fixed=False, + cm_timeout=None): + logger = logging.getLogger('flows') + logging.captureWarnings(True) + logger_warn = logging.getLogger('py.warnings') + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S") + + datafile = api.get_datafile(fid) + target_name = datafile['target_name'] + + # Folder to save output: + output_folder = os.path.join(output_folder_root, target_name, f'{fid:05d}') + + photfile = None + _filehandler = None + try: + # Set the status to indicate that we have started processing: + if autoupload: + api.set_photometry_status(fid, 'running') + + # Create the output directory if it doesn't exist: + os.makedirs(output_folder, exist_ok=True) + + # Also write any logging output to the + _filehandler = logging.FileHandler(os.path.join(output_folder, 'photometry.log'), mode='w') + _filehandler.setFormatter(formatter) + _filehandler.setLevel(logging.INFO) + logger.addHandler(_filehandler) + logger_warn.addHandler(_filehandler) + + photfile = photometry(fileid=fid, output_folder=output_folder, attempt_imagematch=attempt_imagematch, + keep_diff_fixed=keep_diff_fixed, cm_timeout=cm_timeout) + + except (SystemExit, KeyboardInterrupt): + logger.error("Aborted by user or system.") + if os.path.exists(output_folder): + shutil.rmtree(output_folder, ignore_errors=True) + photfile = None + if autoupload: + api.set_photometry_status(fid, 'abort') + + except: # noqa: E722, pragma: no cover + logger.exception("Photometry failed") + photfile = None + if autoupload: + api.set_photometry_status(fid, 'error') + + if _filehandler is not None: + logger.removeHandler(_filehandler) + logger_warn.removeHandler(_filehandler) + + if photfile is not None: + if autoupload: + api.upload_photometry(fid, delete_completed=True) + api.set_photometry_status(fid, 'ingest') + + return photfile + + def main(): - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Run photometry pipeline.') - parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') - parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') - parser.add_argument('-o', '--overwrite', help='Overwrite existing results.', action='store_true') - - group = parser.add_argument_group('Selecting which files to process') - group.add_argument('--fileid', type=int, default=None, action='append', help="Process this file ID. Overrides all other filters.") - group.add_argument('--targetid', type=int, default=None, action='append', help="Only process files from this target.") - group.add_argument('--filter', type=str, default=None, choices=['missing', 'all', 'error']) - group.add_argument('--minversion', type=str, default=None, help="Include files not previously processed with at least this version.") - - group = parser.add_argument_group('Processing settings') - group.add_argument('--threads', type=int, default=1, help="Number of parallel threads to use.") - group.add_argument('--no-imagematch', action='store_true', help="Disable ImageMatch.") - group.add_argument('--autoupload', action='store_true', - help="Automatically upload completed photometry to Flows website. Only do this, if you know what you are doing!") - group.add_argument('--fixposdiff', action='store_true', - help="Fix SN position during PSF photometry of difference image. Useful when difference image is noisy.") - group.add_argument('--wcstimeout', type=int, default=None, help="Timeout in Seconds for WCS.") - args = parser.parse_args() - - # Ensure that all input has been given: - if not args.fileid and not args.targetid and args.filter is None: - parser.error("Please select either a specific FILEID .") - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Number of threads to use: - threads = args.threads - if threads <= 0: - threads = multiprocessing.cpu_count() - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S") - console = logging.StreamHandler(sys.stdout) - console.setFormatter(formatter) - logger = logging.getLogger('flows') - if not logger.hasHandlers(): - logger.addHandler(console) - logger.propagate = False - logger.setLevel(logging_level) - - # If we have started a new processing, perform a cleanup of the - # photometry status indicator. This will change all processes - # still marked as "running" to "abort" if they have been running - # for more than a day: - if args.autoupload: - api.cleanup_photometry_status() - - if args.fileid is not None: - # Run the specified fileids: - fileids = args.fileid - else: - # Ask the API for a list of fileids which are yet to be processed: - if args.targetid is not None: - fileids = [] - for targid in args.targetid: - fileids += api.get_datafiles(targetid=targid, filt=args.filter, minversion=args.minversion) - else: - fileids = api.get_datafiles(filt=args.filter, minversion=args.minversion) - - # Remove duplicates from fileids to be processed: - fileids = list(set(fileids)) - - # Ask the config where we should store the output: - config = load_config() - output_folder_root = config.get('photometry', 'output', fallback='.') - - # Create function wrapper: - process_fileid_wrapper = functools.partial( - process_fileid, - output_folder_root=output_folder_root, - attempt_imagematch=not args.no_imagematch, - autoupload=args.autoupload, - keep_diff_fixed=args.fixposdiff, - cm_timeout=args.wcstimeout) - - if threads > 1: - # Disable printing info messages from the parent function. - # It is going to be all jumbled up anyway. - logger.setLevel(logging.WARNING) - - # There is more than one area to process, so let's start - # a process pool and process them in parallel: - with multiprocessing.Pool(threads) as pool: - for res in pool.imap_unordered(process_fileid_wrapper, fileids): - pass - - else: - # Only single thread so simply run it directly: - for fid in fileids: - logger.info("=" * 72) - logger.info(fid) - logger.info("=" * 72) - process_fileid_wrapper(fid) - -#-------------------------------------------------------------------------------------------------- + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Run photometry pipeline.') + parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') + parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') + parser.add_argument('-o', '--overwrite', help='Overwrite existing results.', action='store_true') + + group = parser.add_argument_group('Selecting which files to process') + group.add_argument('--fileid', type=int, default=None, action='append', + help="Process this file ID. Overrides all other filters.") + group.add_argument('--targetid', type=int, default=None, action='append', + help="Only process files from this target.") + group.add_argument('--filter', type=str, default=None, choices=['missing', 'all', 'error']) + group.add_argument('--minversion', type=str, default=None, + help="Include files not previously processed with at least this version.") + + group = parser.add_argument_group('Processing settings') + group.add_argument('--threads', type=int, default=1, help="Number of parallel threads to use.") + group.add_argument('--no-imagematch', action='store_true', help="Disable ImageMatch.") + group.add_argument('--autoupload', action='store_true', + help="Automatically upload completed photometry to Flows website. Only do this, if you know what you are doing!") + group.add_argument('--fixposdiff', action='store_true', + help="Fix SN position during PSF photometry of difference image. Useful when difference image is noisy.") + group.add_argument('--wcstimeout', type=int, default=None, help="Timeout in Seconds for WCS.") + args = parser.parse_args() + + # Ensure that all input has been given: + if not args.fileid and not args.targetid and args.filter is None: + parser.error("Please select either a specific FILEID .") + + # Set logging level: + logging_level = logging.INFO + if args.quiet: + logging_level = logging.WARNING + elif args.debug: + logging_level = logging.DEBUG + + # Number of threads to use: + threads = args.threads + if threads <= 0: + threads = multiprocessing.cpu_count() + + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S") + console = logging.StreamHandler(sys.stdout) + console.setFormatter(formatter) + logger = logging.getLogger('flows') + if not logger.hasHandlers(): + logger.addHandler(console) + logger.propagate = False + logger.setLevel(logging_level) + + # If we have started a new processing, perform a cleanup of the + # photometry status indicator. This will change all processes + # still marked as "running" to "abort" if they have been running + # for more than a day: + if args.autoupload: + api.cleanup_photometry_status() + + if args.fileid is not None: + # Run the specified fileids: + fileids = args.fileid + else: + # Ask the API for a list of fileids which are yet to be processed: + if args.targetid is not None: + fileids = [] + for targid in args.targetid: + fileids += api.get_datafiles(targetid=targid, filt=args.filter, minversion=args.minversion) + else: + fileids = api.get_datafiles(filt=args.filter, minversion=args.minversion) + + # Remove duplicates from fileids to be processed: + fileids = list(set(fileids)) + + # Ask the config where we should store the output: + config = utils.load_config() + output_folder_root = config.get('photometry', 'output', fallback='.') + + # Create function wrapper: + process_fileid_wrapper = functools.partial(process_fileid, output_folder_root=output_folder_root, + attempt_imagematch=not args.no_imagematch, autoupload=args.autoupload, + keep_diff_fixed=args.fixposdiff, cm_timeout=args.wcstimeout) + + if threads > 1: + # Disable printing info messages from the parent function. + # It is going to be all jumbled up anyway. + logger.setLevel(logging.WARNING) + + # There is more than one area to process, so let's start + # a process pool and process them in parallel: + with multiprocessing.Pool(threads) as pool: + for res in pool.imap_unordered(process_fileid_wrapper, fileids): + pass + + else: + # Only single thread so simply run it directly: + for fid in fileids: + logger.info("=" * 72) + logger.info(fid) + logger.info("=" * 72) + process_fileid_wrapper(fid) + + if __name__ == '__main__': - main() + main() diff --git a/run_plotlc.py b/run_plotlc.py index 519fbfd..c058e6b 100644 --- a/run_plotlc.py +++ b/run_plotlc.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- """ Plot photometry for target, loaded from local photometry working directory. - +@TODO:Refactor out of FLOWS pipeline into flows-tools .. codeauthor:: Emir K .. codeauthor:: Rasmus Handberg """ @@ -14,135 +14,135 @@ from astropy.table import Table from astropy.time import Time from flows.plots import plt, plots_interactive -from flows import api, load_config +from tendrils import api, utils import mplcursors import seaborn as sns -#-------------------------------------------------------------------------------------------------- + def main(): - # All available filters: - all_filters = list(api.get_filters().keys()) + # All available filters: + all_filters = list(api.get_filters().keys()) - # Parser: - parser = argparse.ArgumentParser(description='Plot photometry for target') - parser.add_argument('--target', '-t', type=str, required=True, help="""Target identifier: + # Parser: + parser = argparse.ArgumentParser(description='Plot photometry for target') + parser.add_argument('--target', '-t', type=str, required=True, help="""Target identifier: Can be either the SN name (e.g. 2019yvr) or the Flows target ID.""") - parser.add_argument('--fileid', '-i', type=int, default=None, action='append', help='Specific file ids.') - parser.add_argument('--filter', '-f', type=str, default=None, action='append', choices=all_filters, - help=f'Photmetric filter to plot. If not provided will plot all. Choose between {all_filters}') - parser.add_argument('--offset', '-jd', type=float, default=2458800.0) - parser.add_argument('--subonly', action='store_true', help='Only show template subtracted data points.') - args = parser.parse_args() - - # To use when only plotting some filters - usefilts = args.filters - if usefilts is not None: - usefilts = set(args.filters) - - # To use when only plotting some fileids - # Parse input fileids: - if args.fileid is not None: - # Plot the specified fileid: - fileids = args.fileid - else: - fileids = [] - if len(fileids) > 1: - raise NotImplementedError("This has not been implemented yet") - - # Get the name of the target: - snname = args.target - if snname.isdigit(): - datafiles = api.get_datafiles(int(snname), filt='all') - snname = api.get_datafile(datafiles[0])['target_name'] - - # Change to directory, raise if it does not exist - config = load_config() - workdir_root = config.get('photometry', 'output', fallback='.') - sndir = os.path.join(workdir_root, snname) - if not os.path.isdir(sndir): - print('No such directory as',sndir) - return - - # Get list of photometry files - phot_files = glob.iglob(os.path.join(sndir, '*', 'photometry.ecsv')) - - # Load all data into astropy table - tablerows = [] - for file in phot_files: - # Load photometry file into Table: - AT = Table.read(file, format='ascii.ecsv') - - # Pull out meta-data: - fileid = AT.meta['fileid'] - filt = AT.meta['photfilter'] - jd = Time(AT.meta['obstime-bmjd'], format='mjd', scale='tdb').jd - - # get phot of diff image - AT.add_index('starid') - if -1 in AT['starid']: - mag, mag_err = AT.loc[-1]['mag'], AT.loc[-1]['mag_error'] - sub = True - elif 0 in AT['starid']: - print('No subtraction found for:',file,'in filter',filt) - mag,mag_err = AT.loc[0]['mag'],AT.loc[0]['mag_error'] - sub = False - else: - print('No object phot found, skipping: \n',file) - continue - - tablerows.append((jd, mag, mag_err, filt, sub, fileid)) - - phot = Table( - rows=tablerows, - names=['jd','mag','mag_err','filter','sub','fileid'], - dtype=['float64','float64','float64','S64','bool','int64']) - - # Create list of filters to plot: - filters = list(np.unique(phot['filter'])) - if usefilts: - filters = set(filters).intersection(usefilts) - - # Split photometry table - shifts = dict(zip(filters, np.zeros(len(filters)))) - - # Create the plot: - plots_interactive() - sns.set(style='ticks') - dpi_mult = 1 if not args.subonly else 2 - fig, ax = plt.subplots(figsize=(6.4,4), dpi=130*dpi_mult) - fig.subplots_adjust(top=0.95, left=0.1, bottom=0.1, right=0.97) - - cps = sns.color_palette() - colors = dict(zip(filters,(cps[2],cps[3],cps[0],cps[-1],cps[1]))) - - if args.subonly: - for filt in filters: - lc = phot[(phot['filter'] == filt) & phot['sub']] - ax.errorbar(lc['jd'] - args.offset, lc['mag'] + shifts[filt], lc['mag_err'], - marker='s', linestyle='None', label=filt, color=colors[filt]) - - else: - for filt in filters: - lc = phot[phot['filter'] == filt] - ax.errorbar(lc['jd'] - args.offset, lc['mag'] + shifts[filt], lc['mag_err'], - marker='s', linestyle='None', label=filt, color=colors[filt]) - - ax.invert_yaxis() - ax.legend() - ax.set_xlabel('JD - ' + str(args.offset), fontsize=16) - ax.set_ylabel('App. Mag', fontsize=16) - ax.set_title(snname) - - # Make the points interactive: - def annotate(sel): - lc = phot[phot['filter'] == str(sel.artist.get_label())] - point = lc[sel.target.index] - point = dict(zip(point.colnames, point)) # Convert table row to dict - return sel.annotation.set_text("Fileid: {fileid:d}\nJD: {jd:.3f}\nMag: {mag:.2f}$\\pm${mag_err:.2f}".format(**point)) - - mplcursors.cursor(ax).connect("add", annotate) - plt.show(block=True) - -#-------------------------------------------------------------------------------------------------- + parser.add_argument('--fileid', '-i', type=int, default=None, action='append', help='Specific file ids.') + parser.add_argument('--filter', '-f', type=str, default=None, action='append', choices=all_filters, + help=f'Photmetric filter to plot. If not provided will plot all. Choose between {all_filters}') + parser.add_argument('--offset', '-jd', type=float, default=2458800.0) + parser.add_argument('--subonly', action='store_true', help='Only show template subtracted data points.') + args = parser.parse_args() + + # To use when only plotting some filters + usefilts = args.filters + if usefilts is not None: + usefilts = set(args.filters) + + # To use when only plotting some fileids + # Parse input fileids: + if args.fileid is not None: + # Plot the specified fileid: + fileids = args.fileid + else: + fileids = [] + if len(fileids) > 1: + raise NotImplementedError("This has not been implemented yet") + + # Get the name of the target: + snname = args.target + if snname.isdigit(): + datafiles = api.get_datafiles(int(snname), filt='all') + snname = api.get_datafile(datafiles[0])['target_name'] + + # Change to directory, raise if it does not exist + config = utils.load_config() + workdir_root = config.get('photometry', 'output', fallback='.') + sndir = os.path.join(workdir_root, snname) + if not os.path.isdir(sndir): + print('No such directory as', sndir) + return + + # Get list of photometry files + phot_files = glob.iglob(os.path.join(sndir, '*', 'photometry.ecsv')) + + # Load all data into astropy table + tablerows = [] + for file in phot_files: + # Load photometry file into Table: + AT = Table.read(file, format='ascii.ecsv') + + # Pull out meta-data: + fileid = AT.meta['fileid'] + filt = AT.meta['photfilter'] + jd = Time(AT.meta['obstime-bmjd'], format='mjd', scale='tdb').jd + + # get phot of diff image + AT.add_index('starid') + if -1 in AT['starid']: + mag, mag_err = AT.loc[-1]['mag'], AT.loc[-1]['mag_error'] + sub = True + elif 0 in AT['starid']: + print('No subtraction found for:', file, 'in filter', filt) + mag, mag_err = AT.loc[0]['mag'], AT.loc[0]['mag_error'] + sub = False + else: + print('No object phot found, skipping: \n', file) + continue + + tablerows.append((jd, mag, mag_err, filt, sub, fileid)) + + phot = Table(rows=tablerows, names=['jd', 'mag', 'mag_err', 'filter', 'sub', 'fileid'], + dtype=['float64', 'float64', 'float64', 'S64', 'bool', 'int64']) + + # Create list of filters to plot: + filters = list(np.unique(phot['filter'])) + if usefilts: + filters = set(filters).intersection(usefilts) + + # Split photometry table + shifts = dict(zip(filters, np.zeros(len(filters)))) + + # Create the plot: + plots_interactive() + sns.set(style='ticks') + dpi_mult = 1 if not args.subonly else 2 + fig, ax = plt.subplots(figsize=(6.4, 4), dpi=130 * dpi_mult) + fig.subplots_adjust(top=0.95, left=0.1, bottom=0.1, right=0.97) + + cps = sns.color_palette() + colors = dict(zip(filters, (cps[2], cps[3], cps[0], cps[-1], cps[1]))) + + if args.subonly: + for filt in filters: + lc = phot[(phot['filter'] == filt) & phot['sub']] + ax.errorbar(lc['jd'] - args.offset, lc['mag'] + shifts[filt], lc['mag_err'], marker='s', linestyle='None', + label=filt, color=colors[filt]) + + else: + for filt in filters: + lc = phot[phot['filter'] == filt] + ax.errorbar(lc['jd'] - args.offset, lc['mag'] + shifts[filt], lc['mag_err'], marker='s', linestyle='None', + label=filt, color=colors[filt]) + + ax.invert_yaxis() + ax.legend() + ax.set_xlabel('JD - ' + str(args.offset), fontsize=16) + ax.set_ylabel('App. Mag', fontsize=16) + ax.set_title(snname) + + # Make the points interactive: + def annotate(sel): + lc = phot[phot['filter'] == str(sel.artist.get_label())] + point = lc[sel.target.index] + point = dict(zip(point.colnames, point)) # Convert table row to dict + return sel.annotation.set_text( + "Fileid: {fileid:d}\nJD: {jd:.3f}\nMag: {mag:.2f}$\\pm${mag_err:.2f}".format(**point)) + + mplcursors.cursor(ax).connect("add", annotate) + plt.show(block=True) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - main() + main() diff --git a/run_querytns.py b/run_querytns.py index 1b11a3c..728a4e7 100644 --- a/run_querytns.py +++ b/run_querytns.py @@ -4,7 +4,7 @@ Query TNS for new targets and upload to candidate marshal. https://wis-tns.weizmann.ac.il/ TNS bot apikey must exist in config - +@TODO: Move to flows API .. codeauthor:: Emir Karamehmetoglu .. codeauthor:: Rasmus Handberg """ @@ -16,119 +16,115 @@ from astropy.coordinates import SkyCoord from astropy.time import Time from datetime import datetime, timedelta, timezone -from flows import api, tns +from tendrils import api +from tendrils.utils import load_tns_config, tns_getnames, TNSConfigError, tns_get_obj + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def main(): - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Query TNS and upload to Flows candidates.') - parser.add_argument('-d', '--debug', action='store_true', help='Print debug messages.') - parser.add_argument('-q', '--quiet', action='store_true', help='Only report warnings and errors.') - parser.add_argument('--zmax', type=float, default=0.105, help='Maximum redshift.') - parser.add_argument('--zmin', type=float, default=0.000000001, help='Minimum redshift.') - parser.add_argument('-b', '--days_begin', type=int, default=30, help='Discovery day at least X days before today.') - parser.add_argument('-e', '--days_end', type=int, default=3, help='Discovery day at most X days before today.') - parser.add_argument('-o', '--objtype', type=str, default=[3, 104], help='TNS objtype int given as comma separed string with no spaces') - parser.add_argument('-m', '--limit_months', type=int, default=2, help='Integer number of months to limit TNS search (for speed). \ + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Query TNS and upload to Flows candidates.') + parser.add_argument('-d', '--debug', action='store_true', help='Print debug messages.') + parser.add_argument('-q', '--quiet', action='store_true', help='Only report warnings and errors.') + parser.add_argument('--zmax', type=float, default=0.105, help='Maximum redshift.') + parser.add_argument('--zmin', type=float, default=0.000000001, help='Minimum redshift.') + parser.add_argument('-b', '--days_begin', type=int, default=30, help='Discovery day at least X days before today.') + parser.add_argument('-e', '--days_end', type=int, default=3, help='Discovery day at most X days before today.') + parser.add_argument('-o', '--objtype', type=str, default=[3, 104], + help='TNS objtype int given as comma separed string with no spaces') + parser.add_argument('-m', '--limit_months', type=int, default=2, help='Integer number of months to limit TNS search (for speed). \ Should be greater than days_begin.') - parser.add_argument('--autoupload', action='store_true', help="Automatically upload targets to Flows website. Only do this, if you know what you are doing!") - args = parser.parse_args() - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = logging.StreamHandler() - console.setFormatter(formatter) - logger = logging.getLogger(__name__) - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging_level) - - tqdm_settings = {'disable': None if logger.isEnabledFor(logging.INFO) else True} - - # Try to load TNS config - only used for early stopping - try: - tns._load_tns_config() - except tns.TNSConfigError: - parser.error("Error in TNS configuration.") - return - - # Calculate current date and date range to search: - date_now = datetime.now(timezone.utc).date() - date_end = date_now - timedelta(days=args.days_end) - date_begin = date_now - timedelta(days=args.days_begin) - logger.info('Date begin = %s, date_end = %s', date_begin, date_end) - - # Query TNS for SN names - logger.info('Querying TNS for all targets, this may take awhile') - nms = tns.tns_getnames( - months=args.limit_months, # pre-limit TNS search to candidates reported in the last X months - date_begin=date_begin, - date_end=date_end, - zmin=args.zmin, - zmax=args.zmax, - objtype=args.objtype # Relevant TNS SN Ia subtypes. - ) - logger.debug(nms) - - if not nms: - logger.info("No targets were found.") - return - - # Remove already existing names using flows api - included_names = ['SN' + target['target_name'] for target in api.get_targets()] - nms = list(set(nms) - set(included_names)) - logger.info('Target names obtained: %s', nms) - - # Regular Expression matching any string starting with "ztf" - regex_ztf = re.compile('^ztf', flags=re.IGNORECASE) - regex_sn = re.compile(r'^sn\s*', flags=re.IGNORECASE) - - # Query TNS for object info using API, then upload to FLOWS using API. - num_uploaded = 0 - if args.autoupload: - for name in tqdm(nms, **tqdm_settings): - sn = regex_sn.sub('', name) - logger.debug('querying TNS for: %s', sn) - - # make GET request to TNS via API - reply = tns.tns_get_obj(sn) - - # Parse output - if reply: - logger.debug('GET query successful') - - # Extract object info - coord = SkyCoord(ra=reply['radeg'], dec=reply['decdeg'], unit='deg', frame='icrs') - discovery_date = Time(reply['discoverydate'], format='iso', scale='utc') - ztf = list(filter(regex_ztf.match, reply['internal_names'])) - ztf = None if not ztf else ztf[0] - if 'object_type' in reply and 'name' in reply['object_type']: - sntype = regex_sn.sub('', reply['object_type']['name']) - else: - sntype = None - - # Try to upload to FLOWS - newtargetid = api.add_target(reply['objname'], coord, - redshift=reply['redshift'], - discovery_date=discovery_date, - discovery_mag=reply['discoverymag'], - host_galaxy=reply['hostname'], - ztf=ztf, - sntype=sntype, - status='candidate', - project='flows') - logger.debug('Uploaded to FLOWS with targetid=%d', newtargetid) - num_uploaded += 1 - - logger.info("%d targets uploaded to Flows.", num_uploaded) - -#-------------------------------------------------------------------------------------------------- + parser.add_argument('--autoupload', action='store_true', + help="Automatically upload targets to Flows website. Only do this, if you know what you are doing!") + args = parser.parse_args() + + # Set logging level: + logging_level = logging.INFO + if args.quiet: + logging_level = logging.WARNING + elif args.debug: + logging_level = logging.DEBUG + + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = logging.StreamHandler() + console.setFormatter(formatter) + logger = logging.getLogger(__name__) + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging_level) + + tqdm_settings = {'disable': None if logger.isEnabledFor(logging.INFO) else True} + + # Try to load TNS config - only used for early stopping + try: + load_tns_config() + except TNSConfigError: + parser.error("Error in TNS configuration.") + return + + # Calculate current date and date range to search: + date_now = datetime.now(timezone.utc).date() + date_end = date_now - timedelta(days=args.days_end) + date_begin = date_now - timedelta(days=args.days_begin) + logger.info('Date begin = %s, date_end = %s', date_begin, date_end) + + # Query TNS for SN names + logger.info('Querying TNS for all targets, this may take awhile') + nms = tns_getnames(months=args.limit_months, # pre-limit TNS search to candidates reported in the last X months + date_begin=date_begin, date_end=date_end, zmin=args.zmin, zmax=args.zmax, + objtype=args.objtype# Relevant TNS SN Ia subtypes. + ) + logger.debug(nms) + + if not nms: + logger.info("No targets were found.") + return + + # Remove already existing names using flows api + included_names = ['SN' + target['target_name'] for target in api.get_targets()] + nms = list(set(nms) - set(included_names)) + logger.info('Target names obtained: %s', nms) + + # Regular Expression matching any string starting with "ztf" + regex_ztf = re.compile('^ztf', flags=re.IGNORECASE) + regex_sn = re.compile(r'^sn\s*', flags=re.IGNORECASE) + + # Query TNS for object info using API, then upload to FLOWS using API. + num_uploaded = 0 + if args.autoupload: + for name in tqdm(nms, **tqdm_settings): + sn = regex_sn.sub('', name) + logger.debug('querying TNS for: %s', sn) + + # make GET request to TNS via API + reply = tns_get_obj(sn) + + # Parse output + if reply: + logger.debug('GET query successful') + + # Extract object info + coord = SkyCoord(ra=reply['radeg'], dec=reply['decdeg'], unit='deg', frame='icrs') + discovery_date = Time(reply['discoverydate'], format='iso', scale='utc') + ztf = list(filter(regex_ztf.match, reply['internal_names'])) + ztf = None if not ztf else ztf[0] + if 'object_type' in reply and 'name' in reply['object_type']: + sntype = regex_sn.sub('', reply['object_type']['name']) + else: + sntype = None + + # Try to upload to FLOWS + newtargetid = api.add_target(reply['objname'], coord, redshift=reply['redshift'], + discovery_date=discovery_date, discovery_mag=reply['discoverymag'], + host_galaxy=reply['hostname'], ztf=ztf, sntype=sntype, status='candidate', + project='flows') + logger.debug('Uploaded to FLOWS with targetid=%d', newtargetid) + num_uploaded += 1 + + logger.info("%d targets uploaded to Flows.", num_uploaded) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - main() + main() diff --git a/run_upload_photometry.py b/run_upload_photometry.py index e4f859c..f3aa757 100644 --- a/run_upload_photometry.py +++ b/run_upload_photometry.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- """ Upload photometry results to Flows server. @@ -8,37 +6,38 @@ import argparse import logging -from flows import api +from tendrils import api + -#-------------------------------------------------------------------------------------------------- def main(): - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Upload photometry.') - parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') - parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') - parser.add_argument('fileids', type=int, help='File IDs to be uploaded.', nargs='+') - args = parser.parse_args() - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = logging.StreamHandler() - console.setFormatter(formatter) - logger = logging.getLogger('flows') - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging_level) - - # Loop through the fileids and upload the results: - for fid in args.fileids: - api.upload_photometry(fid) - -#-------------------------------------------------------------------------------------------------- + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Upload photometry.') + parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') + parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') + parser.add_argument('fileids', type=int, help='File IDs to be uploaded.', nargs='+') + args = parser.parse_args() + + # Set logging level: + logging_level = logging.INFO + if args.quiet: + logging_level = logging.WARNING + elif args.debug: + logging_level = logging.DEBUG + + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = logging.StreamHandler() + console.setFormatter(formatter) + logger = logging.getLogger('flows') + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging_level) + + # Loop through the fileids and upload the results: + for fid in args.fileids: + api.upload_photometry(fid) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - main() + main() diff --git a/run_visibility.py b/run_visibility.py index 574de4f..05934ab 100644 --- a/run_visibility.py +++ b/run_visibility.py @@ -9,15 +9,15 @@ import flows if __name__ == '__main__': - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Run photometry pipeline.') - parser.add_argument('-t', '--target', type=str, help='TIC identifier of target.', nargs='?', default=2) - parser.add_argument('-s', '--site', type=int, help='TIC identifier of target.', nargs='?', default=None) - parser.add_argument('-d', '--date', type=str, help='TIC identifier of target.', nargs='?', default=None) - parser.add_argument('-o', '--output', type=str, help='TIC identifier of target.', nargs='?', default=None) - args = parser.parse_args() + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Run photometry pipeline.') + parser.add_argument('-t', '--target', type=str, help='TIC identifier of target.', nargs='?', default=2) + parser.add_argument('-s', '--site', type=int, help='TIC identifier of target.', nargs='?', default=None) + parser.add_argument('-d', '--date', type=str, help='TIC identifier of target.', nargs='?', default=None) + parser.add_argument('-o', '--output', type=str, help='TIC identifier of target.', nargs='?', default=None) + args = parser.parse_args() - if args.output is None: - plots_interactive() + if args.output is None: + plots_interactive() - flows.visibility(target=args.target, siteid=args.site, date=args.date, output=args.output) + flows.visibility(target=args.target, siteid=args.site, date=args.date, output=args.output) diff --git a/setup.cfg b/setup.cfg index 4af7b14..094903d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,16 +1,13 @@ [flake8] exclude = .git,__pycache__,notes -max-line-length = 99 +# To be compliant with black +max-line-length = 120 +#To be compliant with black +extend-ignore = E203 # Enable flake8-logging-format: enable-extensions = G -# Configuration of flake8-tabs: -use-flake8-tabs = True -blank-lines-indent = never -indent-tabs-def = 1 -indent-style = tab - ignore = E117, # over-indented (set when using tabs) E127, # continuation line over-indented for visual indent diff --git a/tests/conftest.py b/tests/conftest.py index 89638b0..d6d8345 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,60 +9,56 @@ import pytest import sys import os -#import shutil import configparser import subprocess import shlex if sys.path[0] != os.path.abspath(os.path.join(os.path.dirname(__file__), '..')): - sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def capture_cli(script, params=[], mpiexec=False): + if isinstance(params, str): + params = shlex.split(params) - if isinstance(params, str): - params = shlex.split(params) + cmd = [sys.executable, script.strip()] + list(params) + if mpiexec: + cmd = ['mpiexec', '-n', '2'] + cmd - cmd = [sys.executable, script.strip()] + list(params) - if mpiexec: - cmd = ['mpiexec', '-n', '2'] + cmd + print(f"Command: {cmd}") + proc = subprocess.Popen(cmd, cwd=os.path.join(os.path.dirname(__file__), '..'), stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=True) + out, err = proc.communicate() + exitcode = proc.returncode + proc.kill() - print(f"Command: {cmd}") - proc = subprocess.Popen(cmd, - cwd=os.path.join(os.path.dirname(__file__), '..'), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True - ) - out, err = proc.communicate() - exitcode = proc.returncode - proc.kill() + print(f"ExitCode: {exitcode:d}") + print("StdOut:\n%s" % out.strip()) + print("StdErr:\n%s" % err.strip()) + return out, err, exitcode - print(f"ExitCode: {exitcode:d}") - print("StdOut:\n%s" % out.strip()) - print("StdErr:\n%s" % err.strip()) - return out, err, exitcode -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- @pytest.fixture(scope='session') def SETUP_CONFIG(): - """ - Fixture which sets up a dummy config-file which allows for simple testing only. - """ - config_file = os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'flows', 'config.ini'))) - if os.path.exists(config_file): - yield config_file - else: - confstr = os.environ.get('FLOWS_CONFIG') - if confstr is None: - raise RuntimeError("Config file can not be set up.") + """ + Fixture which sets up a dummy config-file which allows for simple testing only. + """ + config_file = os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'flows', 'config.ini'))) + if os.path.exists(config_file): + yield config_file + else: + confstr = os.environ.get('FLOWS_CONFIG') + if confstr is None: + raise RuntimeError("Config file can not be set up.") - # Write minimal config file that can be used for testing: - config = configparser.ConfigParser() - config.read_string(confstr) - with open(config_file, 'w') as fid: - config.write(fid) - fid.flush() + # Write minimal config file that can be used for testing: + config = configparser.ConfigParser() + config.read_string(confstr) + with open(config_file, 'w') as fid: + config.write(fid) + fid.flush() - yield config_file - os.remove(config_file) + yield config_file + os.remove(config_file) diff --git a/tests/test_api.py b/tests/test_api.py deleted file mode 100644 index c630097..0000000 --- a/tests/test_api.py +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Test API calls. - -.. codeauthor:: Rasmus Handberg -""" - -import pytest -import os.path -import tempfile -import numpy as np -from astropy.coordinates import EarthLocation -from astropy.table import Table -import conftest # noqa: F401 -from flows import api, load_config - -#-------------------------------------------------------------------------------------------------- -def test_api_get_targets(SETUP_CONFIG): - - tab = api.get_targets() - print(tab) - - assert isinstance(tab, list) - assert len(tab) > 0 - for target in tab: - assert isinstance(target, dict) - assert 'target_name' in target - assert 'targetid' in target - assert 'ra' in target - assert 'decl' in target - assert 'target_status' in target - -#-------------------------------------------------------------------------------------------------- -def test_api_get_target(SETUP_CONFIG): - - tab = api.get_target(2) - print(tab) - - assert isinstance(tab, dict) - assert tab['target_name'] == '2019yvr' - assert tab['targetid'] == 2 - assert tab['target_status'] == 'target' - assert tab['ztf_id'] == 'ZTF20aabqkxs' - -#-------------------------------------------------------------------------------------------------- -def test_api_get_datafiles(SETUP_CONFIG): - - tab = api.get_datafiles(targetid=2, filt='all') - print(tab) - assert isinstance(tab, list) - assert len(tab) > 0 - for fid in tab: - assert isinstance(fid, int) - - fileid = tab[0] - tab = api.get_datafile(fileid) - print(tab) - assert tab['fileid'] == fileid - assert tab['targetid'] == 2 - -#-------------------------------------------------------------------------------------------------- -def test_api_get_filters(SETUP_CONFIG): - - tab = api.get_filters() - print(tab) - assert isinstance(tab, dict) - for key, value in tab.items(): - assert isinstance(value, dict) - assert value['photfilter'] == key - assert 'wavelength_center' in value - -#-------------------------------------------------------------------------------------------------- -def test_api_get_sites(SETUP_CONFIG): - - tab = api.get_all_sites() - print(tab) - assert isinstance(tab, list) - assert len(tab) > 0 - for site in tab: - assert isinstance(site, dict) - assert isinstance(site['siteid'], int) - assert 'sitename' in site - assert isinstance(site['EarthLocation'], EarthLocation) - - site0 = tab[0] - print(site0) - tab = api.get_site(site0['siteid']) - print(tab) - assert isinstance(tab, dict) - assert tab == site0 - -#-------------------------------------------------------------------------------------------------- -def test_api_get_catalog(SETUP_CONFIG): - - cat = api.get_catalog(2, output='table') - print(cat) - - assert isinstance(cat, dict) - - target = cat['target'] - assert isinstance(target, Table) - assert len(target) == 1 - assert target['targetid'] == 2 - assert target['target_name'] == '2019yvr' - - ref = cat['references'] - assert isinstance(ref, Table) - - avoid = cat['avoid'] - assert isinstance(avoid, Table) - -#-------------------------------------------------------------------------------------------------- -def test_api_get_lightcurve(SETUP_CONFIG): - - tab = api.get_lightcurve(2) - print(tab) - - assert isinstance(tab, Table) - assert len(tab) > 0 - assert 'time' in tab.colnames - assert 'mag_raw' in tab.colnames - -#-------------------------------------------------------------------------------------------------- -def test_api_get_photometry(SETUP_CONFIG): - with tempfile.TemporaryDirectory() as tmpdir: - # Set cache to the temporary directory: - # FIXME: There is a potential race condition here! - config = load_config() - config.set('api', 'photometry_cache', tmpdir) - print(config) - - # The cache file should NOT exists: - assert not os.path.isfile(os.path.join(tmpdir, 'photometry-499.ecsv')), "Cache file already exists" - - # Download a photometry from API: - tab = api.get_photometry(499) - print(tab) - - # Basic tests of table: - assert isinstance(tab, Table) - assert len(tab) > 0 - assert 'starid' in tab.colnames - assert 'ra' in tab.colnames - assert 'decl' in tab.colnames - assert 'mag' in tab.colnames - assert 'mag_error' in tab.colnames - assert np.sum(tab['starid'] == 0) == 1, "There should be one starid=0" - - # Meta-information: - assert tab.meta['targetid'] == 2 - assert tab.meta['fileid'] == 179 - assert tab.meta['photfilter'] == 'B' - - # The cache file should now exists: - assert os.path.isfile(os.path.join(tmpdir, 'photometry-499.ecsv')), "Cache file does not exist" - - # Asking for the same photometry should now load from cache: - tab2 = api.get_photometry(499) - print(tab2) - - # The two tables should be identical: - assert tab2.meta == tab.meta - assert tab2.colnames == tab.colnames - for col in tab.colnames: - np.testing.assert_allclose(tab2[col], tab[col], equal_nan=True) - -#-------------------------------------------------------------------------------------------------- -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py index c82d21f..9771d53 100644 --- a/tests/test_catalogs.py +++ b/tests/test_catalogs.py @@ -10,107 +10,92 @@ import numpy as np from astropy.coordinates import SkyCoord from astropy.table import Table -import conftest # noqa: F401 +import conftest # noqa: F401 from flows import catalogs -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def test_query_simbad(): + # Coordinates around test-object (2019yvr): + coo_centre = SkyCoord(ra=256.727512, dec=30.271482, unit='deg', frame='icrs') - # Coordinates around test-object (2019yvr): - coo_centre = SkyCoord( - ra=256.727512, - dec=30.271482, - unit='deg', - frame='icrs' - ) + results, simbad = catalogs.query_simbad(coo_centre) - results, simbad = catalogs.query_simbad(coo_centre) + assert isinstance(results, Table) + assert isinstance(simbad, SkyCoord) + assert len(results) > 0 + results.pprint_all(50) - assert isinstance(results, Table) - assert isinstance(simbad, SkyCoord) - assert len(results) > 0 - results.pprint_all(50) -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def test_query_skymapper(): + # Coordinates around test-object (2021aess): + coo_centre = SkyCoord(ra=53.4505, dec=-19.495725, unit='deg', frame='icrs') - # Coordinates around test-object (2021aess): - coo_centre = SkyCoord( - ra=53.4505, - dec=-19.495725, - unit='deg', - frame='icrs' - ) + results, skymapper = catalogs.query_skymapper(coo_centre) - results, skymapper = catalogs.query_skymapper(coo_centre) + assert isinstance(results, Table) + assert isinstance(skymapper, SkyCoord) + assert len(results) > 0 + results.pprint_all(50) - assert isinstance(results, Table) - assert isinstance(skymapper, SkyCoord) - assert len(results) > 0 - results.pprint_all(50) -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- -@pytest.mark.parametrize('ra,dec', [ - [256.727512, 30.271482], # 2019yvr - [58.59512, -19.18172], # 2009D -]) +@pytest.mark.parametrize('ra,dec', [[256.727512, 30.271482], # 2019yvr + [58.59512, -19.18172], # 2009D + ]) def test_download_catalog(SETUP_CONFIG, ra, dec): - - # Check if CasJobs have been configured, and skip the entire test if it isn't. - # This has to be done like this, to avoid problems when config.ini doesn't exist. - try: - catalogs.configure_casjobs() - except catalogs.CasjobsError: - pytest.skip("CasJobs not configured") - - # Coordinates around test-object (2019yvr): - coo_centre = SkyCoord( - ra=ra, - dec=dec, - unit='deg', - frame='icrs' - ) - - tab = catalogs.query_all(coo_centre) - print(tab) - - assert isinstance(tab, Table), "Should return a Table" - results = catalogs.convert_table_to_dict(tab) - - assert isinstance(results, list), "Should return a list" - for obj in results: - assert isinstance(obj, dict), "Each element should be a dict" - - # Check columns: - assert 'starid' in obj and obj['starid'] > 0 - assert 'ra' in obj and np.isfinite(obj['ra']) - assert 'decl' in obj and np.isfinite(obj['decl']) - assert 'pm_ra' in obj - assert 'pm_dec' in obj - assert 'gaia_mag' in obj - assert 'gaia_bp_mag' in obj - assert 'gaia_rp_mag' in obj - assert 'gaia_variability' in obj - assert 'B_mag' in obj - assert 'V_mag' in obj - assert 'u_mag' in obj - assert 'g_mag' in obj - assert 'r_mag' in obj - assert 'i_mag' in obj - assert 'z_mag' in obj - assert 'H_mag' in obj - assert 'J_mag' in obj - assert 'K_mag' in obj - - # All values should be finite number or None: - for key, val in obj.items(): - if key not in ('starid', 'gaia_variability'): - assert val is None or np.isfinite(val), f"{key} is not a valid value: {val}" - - # TODO: Manually check a target from this position if the merge is correct. - #assert False - -#-------------------------------------------------------------------------------------------------- + # Check if CasJobs have been configured, and skip the entire test if it isn't. + # This has to be done like this, to avoid problems when config.ini doesn't exist. + try: + catalogs.configure_casjobs() + except catalogs.CasjobsError: + pytest.skip("CasJobs not configured") + + # Coordinates around test-object (2019yvr): + coo_centre = SkyCoord(ra=ra, dec=dec, unit='deg', frame='icrs') + + tab = catalogs.query_all(coo_centre) + print(tab) + + assert isinstance(tab, Table), "Should return a Table" + results = catalogs.convert_table_to_dict(tab) + + assert isinstance(results, list), "Should return a list" + for obj in results: + assert isinstance(obj, dict), "Each element should be a dict" + + # Check columns: + assert 'starid' in obj and obj['starid'] > 0 + assert 'ra' in obj and np.isfinite(obj['ra']) + assert 'decl' in obj and np.isfinite(obj['decl']) + assert 'pm_ra' in obj + assert 'pm_dec' in obj + assert 'gaia_mag' in obj + assert 'gaia_bp_mag' in obj + assert 'gaia_rp_mag' in obj + assert 'gaia_variability' in obj + assert 'B_mag' in obj + assert 'V_mag' in obj + assert 'u_mag' in obj + assert 'g_mag' in obj + assert 'r_mag' in obj + assert 'i_mag' in obj + assert 'z_mag' in obj + assert 'H_mag' in obj + assert 'J_mag' in obj + assert 'K_mag' in obj + + # All values should be finite number or None: + for key, val in obj.items(): + if key not in ('starid', 'gaia_variability'): + assert val is None or np.isfinite(val), f"{key} is not a valid value: {val}" + + +# TODO: Manually check a target from this position if the merge is correct. +# assert False + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 62b1219..0dc3ffa 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -1,9 +1,5 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- """ Test loading of images. - -.. codeauthor:: Rasmus Handberg """ import pytest @@ -12,8 +8,8 @@ from astropy.wcs import WCS from astropy.coordinates import SkyCoord import os.path -import conftest # noqa: F401 -from flows.api import get_filters +import conftest # noqa: F401 +from tendrils import api from flows.load_image import load_image @@ -30,8 +26,8 @@ #['2021aess_B01_20220207v1.fits.gz', 5], ]) def test_load_image(fpath, siteid): - # Get list of all available filters: - all_filters = set(get_filters().keys()) + # Get list of all available filters: + all_filters = set(api.get_filters().keys()) # The test input directory containing the test-images: INPUT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'input') @@ -46,22 +42,23 @@ def test_load_image(fpath, siteid): # Load the image from the test-set: img = load_image(os.path.join(INPUT_DIR, fpath), target_coord=target_coord) - # Check the attributes of the image object: - assert isinstance(img.image, np.ndarray) - assert img.image.dtype in ('float32', 'float64') - assert isinstance(img.mask, np.ndarray) - assert img.mask.dtype == 'bool' - assert isinstance(img.clean, np.ma.MaskedArray) - assert img.clean.dtype == img.image.dtype - assert isinstance(img.obstime, Time) - assert isinstance(img.exptime, float) - assert img.exptime > 0 - assert isinstance(img.wcs, WCS) - assert isinstance(img.site, dict) - assert img.site['siteid'] == siteid - assert isinstance(img.photfilter, str) - assert img.photfilter in all_filters + # Check the attributes of the image object: + assert isinstance(img.image, np.ndarray) + assert img.image.dtype in ('float32', 'float64') + assert isinstance(img.mask, np.ndarray) + assert img.mask.dtype == 'bool' + assert isinstance(img.clean, np.ma.MaskedArray) + assert img.clean.dtype == img.image.dtype + assert isinstance(img.obstime, Time) + assert isinstance(img.exptime, float) + assert img.exptime > 0 + assert isinstance(img.wcs, WCS) + assert isinstance(img.site, dict) + assert img.site['siteid'] == siteid + assert isinstance(img.photfilter, str) + assert img.photfilter in all_filters -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/test_photometry.py b/tests/test_photometry.py index 15b379f..3bab9cb 100644 --- a/tests/test_photometry.py +++ b/tests/test_photometry.py @@ -7,14 +7,17 @@ """ import pytest -import conftest # noqa: F401 -from flows import photometry # noqa: F401 +import conftest # noqa: F401 +from flows import photometry # noqa: F401 -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def test_import_photometry(): - pass - #assert photometry + pass + + +# assert photometry -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) diff --git a/tests/test_tns.py b/tests/test_tns.py index 98103c5..82c8d62 100644 --- a/tests/test_tns.py +++ b/tests/test_tns.py @@ -13,80 +13,72 @@ from conftest import capture_cli from flows import tns -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") + +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_tns_search(SETUP_CONFIG): + coo_centre = SkyCoord(ra=191.283890127, dec=-0.45909033652, unit='deg', frame='icrs') + res = tns.tns_search(coo_centre) - coo_centre = SkyCoord( - ra=191.283890127, - dec=-0.45909033652, - unit='deg', - frame='icrs' - ) - res = tns.tns_search(coo_centre) + print(res) + assert res[0]['objname'] == '2019yvr' + assert res[0]['prefix'] == 'SN' - print(res) - assert res[0]['objname'] == '2019yvr' - assert res[0]['prefix'] == 'SN' -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_tns_get_obj(SETUP_CONFIG): + res = tns.tns_get_obj('2019yvr') - res = tns.tns_get_obj('2019yvr') + print(res) + assert res['objname'] == '2019yvr' + assert res['name_prefix'] == 'SN' - print(res) - assert res['objname'] == '2019yvr' - assert res['name_prefix'] == 'SN' -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_tns_get_obj_noexist(SETUP_CONFIG): - res = tns.tns_get_obj('1892doesnotexist') - print(res) - assert res is None - -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") -@pytest.mark.parametrize('date_begin,date_end', [ - ('2019-01-01', '2019-02-01'), - (datetime.date(2019, 1, 1), datetime.date(2019, 2, 1)), - (datetime.datetime(2019, 1, 1, 12, 0), datetime.datetime(2019, 2, 1, 12, 0)) -]) + res = tns.tns_get_obj('1892doesnotexist') + print(res) + assert res is None + + +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") +@pytest.mark.parametrize('date_begin,date_end', + [('2019-01-01', '2019-02-01'), (datetime.date(2019, 1, 1), datetime.date(2019, 2, 1)), + (datetime.datetime(2019, 1, 1, 12, 0), datetime.datetime(2019, 2, 1, 12, 0))]) def test_tns_getnames(SETUP_CONFIG, date_begin, date_end): + names = tns.tns_getnames(date_begin=date_begin, date_end=date_end, zmin=0, zmax=0.105, objtype=3) + + print(names) + assert isinstance(names, list), "Should return a list" + for n in names: + assert isinstance(n, str), "Each element should be a string" + assert n.startswith('SN'), "All names should begin with 'SN'" + assert 'SN2019A' in names, "SN2019A should be in the list" - names = tns.tns_getnames( - date_begin=date_begin, - date_end=date_end, - zmin=0, - zmax=0.105, - objtype=3 - ) - - print(names) - assert isinstance(names, list), "Should return a list" - for n in names: - assert isinstance(n, str), "Each element should be a string" - assert n.startswith('SN'), "All names should begin with 'SN'" - assert 'SN2019A' in names, "SN2019A should be in the list" - -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def test_tns_getnames_wronginput(SETUP_CONFIG): - # Wrong dates should result in ValueError: - with pytest.raises(ValueError): - tns.tns_getnames( - date_begin=datetime.date(2019, 1, 1), - date_end=datetime.date(2017, 1, 1) - ) - -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") + # Wrong dates should result in ValueError: + with pytest.raises(ValueError): + tns.tns_getnames(date_begin=datetime.date(2019, 1, 1), date_end=datetime.date(2017, 1, 1)) + + +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_run_querytns(SETUP_CONFIG): + # Run the command line interface: + out, err, exitcode = capture_cli('run_querytns.py') + assert exitcode == 0 - # Run the command line interface: - out, err, exitcode = capture_cli('run_querytns.py') - assert exitcode == 0 -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) diff --git a/tests/test_ztf.py b/tests/test_ztf.py index fa0c550..61ae68d 100644 --- a/tests/test_ztf.py +++ b/tests/test_ztf.py @@ -14,73 +14,62 @@ import tempfile import os from conftest import capture_cli -from flows import ztf +from tendrils.utils import ztf -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def test_ztf_id(): + coo_centre = SkyCoord(ra=191.283890127, dec=-0.45909033652, unit='deg', frame='icrs') + ztfid = ztf.query_ztf_id(coo_centre) + assert ztfid == 'ZTF20aabqkxs' + + # With the correct discovery date we should get the same result: + dd = Time('2019-12-27 12:30:14', format='iso', scale='utc') + ztfid = ztf.query_ztf_id(coo_centre, discovery_date=dd) + assert ztfid == 'ZTF20aabqkxs' + + # With a wrong discovery date, we should not get a ZTF id: + dd = Time('2021-12-24 18:00:00', format='iso', scale='utc') + ztfid = ztf.query_ztf_id(coo_centre, discovery_date=dd) + assert ztfid is None + + coo_centre = SkyCoord(ra=181.6874198, dec=67.1649528, unit='deg', frame='icrs') + ztfid = ztf.query_ztf_id(coo_centre) + assert ztfid == 'ZTF21aatyplr' + - coo_centre = SkyCoord( - ra=191.283890127, - dec=-0.45909033652, - unit='deg', - frame='icrs' - ) - ztfid = ztf.query_ztf_id(coo_centre) - assert ztfid == 'ZTF20aabqkxs' - - # With the correct discovery date we should get the same result: - dd = Time('2019-12-27 12:30:14', format='iso', scale='utc') - ztfid = ztf.query_ztf_id(coo_centre, discovery_date=dd) - assert ztfid == 'ZTF20aabqkxs' - - # With a wrong discovery date, we should not get a ZTF id: - dd = Time('2021-12-24 18:00:00', format='iso', scale='utc') - ztfid = ztf.query_ztf_id(coo_centre, discovery_date=dd) - assert ztfid is None - - coo_centre = SkyCoord( - ra=181.6874198, - dec=67.1649528, - unit='deg', - frame='icrs' - ) - ztfid = ztf.query_ztf_id(coo_centre) - assert ztfid == 'ZTF21aatyplr' - -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- @pytest.mark.parametrize('targetid', [2, 865]) def test_ztf_photometry(SETUP_CONFIG, targetid): + tab = ztf.download_ztf_photometry(targetid) + print(tab) - tab = ztf.download_ztf_photometry(targetid) - print(tab) + assert isinstance(tab, Table) + assert 'time' in tab.colnames + assert 'photfilter' in tab.colnames + assert 'mag' in tab.colnames + assert 'mag_err' in tab.colnames + assert np.all(np.isfinite(tab['time'])) + assert np.all(np.isfinite(tab['mag'])) + assert np.all(np.isfinite(tab['mag_err'])) - assert isinstance(tab, Table) - assert 'time' in tab.colnames - assert 'photfilter' in tab.colnames - assert 'mag' in tab.colnames - assert 'mag_err' in tab.colnames - assert np.all(np.isfinite(tab['time'])) - assert np.all(np.isfinite(tab['mag'])) - assert np.all(np.isfinite(tab['mag_err'])) -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- @pytest.mark.parametrize('targetid', [2, 865]) def test_run_download_ztf(targetid): - with tempfile.TemporaryDirectory() as tmpdir: - # Nothing exists before running: - assert len(os.listdir(tmpdir)) == 0 - - # Run the command line interface: - out, err, exitcode = capture_cli('run_download_ztf.py', [ - f'--target={targetid:d}', - '-o', tmpdir - ]) - assert exitcode == 0 - - # The output directory should now have two files: - print(os.listdir(tmpdir)) - assert len(os.listdir(tmpdir)) == 2 - -#-------------------------------------------------------------------------------------------------- + with tempfile.TemporaryDirectory() as tmpdir: + # Nothing exists before running: + assert len(os.listdir(tmpdir)) == 0 + + # Run the command line interface: + out, err, exitcode = capture_cli('run_download_ztf.py', [f'--target={targetid:d}', '-o', tmpdir]) + assert exitcode == 0 + + # The output directory should now have two files: + print(os.listdir(tmpdir)) + assert len(os.listdir(tmpdir)) == 2 + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__])