From cbca2f3fd3eb1e39ad68f3eeae16f5d0ae3b85e3 Mon Sep 17 00:00:00 2001
From: Emir
Date: Mon, 14 Feb 2022 21:30:27 +0100
Subject: [PATCH 1/3] Change tabs to spaces, auto-fix PEP8 violations, prepare for black

We move to a 4-space indentation format, automatically fix the PEP8 mistakes that can be fixed automatically, and change the flake8 settings in setup.cfg to match what black will produce; a sketch of such a flake8 section follows below. The line limit is currently set at 120, but that can be discussed.
---
 flows/__init__.py | 1 +
 flows/aadc_db.py | 101 +-
 flows/api/catalogs.py | 217 ++--
 flows/api/datafiles.py | 117 +-
 flows/api/filters.py | 35 +-
 flows/api/lightcurves.py | 60 +-
 flows/api/photometry_api.py | 309 +++--
 flows/api/set_photometry_status.py | 147 +--
 flows/api/sites.py | 64 +-
 flows/api/targets.py | 260 ++--
 flows/catalogs.py | 1188 +++++++++----------
 flows/config.py | 27 +-
 flows/coordinatematch/coordinatematch.py | 597 +++++-----
 flows/coordinatematch/wcs.py | 301 +++--
 flows/epsfbuilder/epsfbuilder.py | 59 +-
 flows/load_image.py | 696 +++++------
 flows/photometry.py | 1368 ++++++++++------------
 flows/plots.py | 439 ++++---
 flows/reference_cleaning.py | 595 +++++-----
 flows/run_imagematch.py | 295 +++--
 flows/tns.py | 473 ++++----
 flows/utilities.py | 29 +-
 flows/version.py | 200 ++--
 flows/visibility.py | 227 ++--
 flows/zeropoint.py | 87 +-
 flows/ztf.py | 252 ++--
 notes/disk_covering_problem.py | 41 +-
 notes/fix_ztf_ids.py | 37 +-
 notes/update_all_catalogs.py | 80 +-
 run_catalogs.py | 72 +-
 run_download_ztf.py | 205 ++--
 run_ingest.py | 1280 ++++++++++----------
 run_photometry.py | 328 +++---
 run_plotlc.py | 249 ++--
 run_querytns.py | 217 ++--
 run_upload_photometry.py | 62 +-
 run_visibility.py | 20 +-
 setup.cfg | 9 +-
 tests/conftest.py | 79 +-
 tests/test_api.py | 252 ++--
 tests/test_catalogs.py | 163 ++-
 tests/test_load_image.py | 91 +-
 tests/test_photometry.py | 17 +-
 tests/test_tns.py | 114 +-
 tests/test_ztf.py | 105 +-
 45 files changed, 5605 insertions(+), 5960 deletions(-)

diff --git a/flows/__init__.py b/flows/__init__.py index 488a5bb..af45bd1 100644
--- a/flows/__init__.py
+++ b/flows/__init__.py
@@ -8,4 +8,5 @@ from .config import load_config
 from .version import get_version
+
 __version__ = get_version(pep440=False)
diff --git a/flows/aadc_db.py b/flows/aadc_db.py index 1e0a283..a4ce7c4 100644
--- a/flows/aadc_db.py
+++ b/flows/aadc_db.py
@@ -15,53 +15,54 @@ import getpass
 from .config import load_config
-#--------------------------------------------------------------------------------------------------
-class AADC_DB(object): # pragma: no cover
- """
- Connection to the central TASOC database.
-
- Attributes:
- conn (`psycopg2.Connection` object): Connection to PostgreSQL database.
- cursor (`psycopg2.Cursor` object): Cursor to use in database.
- """
-
- def __init__(self, username=None, password=None):
- """
- Open connection to central TASOC database.
-
- If ``username`` or ``password`` is not provided or ``None``,
- the user will be prompted for them.
-
- Parameters:
- username (string or None, optional): Username for AADC database.
- password (string or None, optional): Password for AADC database.
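The setup.cfg hunk itself appears only in the diffstat of this excerpt, so the exact flake8 values are not shown. As a rough sketch (assumed values, not the actual hunk), a black-compatible flake8 section at the stated 120-character limit typically looks like this:

[flake8]
max-line-length = 120
# E203 (whitespace before ':') and W503 (line break before binary operator)
# conflict with black's output and are commonly disabled alongside it:
extend-ignore = E203, W503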
- """ - - config = load_config() - - if username is None: - username = config.get('database', 'username', fallback=None) - if username is None: - default_username = getpass.getuser() - username = input('Username [%s]: ' % default_username) - if username == '': - username = default_username - - if password is None: - password = config.get('database', 'password', fallback=None) - if password is None: - password = getpass.getpass('Password: ') - - # Open database connection: - self.conn = psql.connect('host=10.28.0.127 user=' + username + ' password=' + password + ' dbname=db_aadc') - self.cursor = self.conn.cursor(cursor_factory=DictCursor) - - def close(self): - self.cursor.close() - self.conn.close() - - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - self.close() + +# -------------------------------------------------------------------------------------------------- +class AADC_DB(object): # pragma: no cover + """ + Connection to the central TASOC database. + + Attributes: + conn (`psycopg2.Connection` object): Connection to PostgreSQL database. + cursor (`psycopg2.Cursor` object): Cursor to use in database. + """ + + def __init__(self, username=None, password=None): + """ + Open connection to central TASOC database. + + If ``username`` or ``password`` is not provided or ``None``, + the user will be prompted for them. + + Parameters: + username (string or None, optional): Username for AADC database. + password (string or None, optional): Password for AADC database. + """ + + config = load_config() + + if username is None: + username = config.get('database', 'username', fallback=None) + if username is None: + default_username = getpass.getuser() + username = input('Username [%s]: ' % default_username) + if username == '': + username = default_username + + if password is None: + password = config.get('database', 'password', fallback=None) + if password is None: + password = getpass.getpass('Password: ') + + # Open database connection: + self.conn = psql.connect('host=10.28.0.127 user=' + username + ' password=' + password + ' dbname=db_aadc') + self.cursor = self.conn.cursor(cursor_factory=DictCursor) + + def close(self): + self.cursor.close() + self.conn.close() + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.close() diff --git a/flows/api/catalogs.py b/flows/api/catalogs.py index 6657b89..20714b9 100644 --- a/flows/api/catalogs.py +++ b/flows/api/catalogs.py @@ -12,113 +12,116 @@ from functools import lru_cache from ..config import load_config -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- @lru_cache(maxsize=10) def get_catalog(target, radius=None, output='table'): - """ - - Parameters: - target (int or str): - radius (float, optional): Radius around target in degrees to return targets for. - outout (str, optional): Desired output format. Choises are 'table', 'dict', 'json'. - Default='table'. - - Returns: - dict: Dictionary with three members: - - 'target': Information about target. - - 'references': Table with information about reference stars close to target. - - 'avoid': Table with stars close to target which should be avoided in FOV selection. - - .. 
codeauthor:: Rasmus Handberg - """ - - assert output in ('table', 'json', 'dict'), "Invalid output format" - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # - r = requests.get('https://flows.phys.au.dk/api/reference_stars.php', - params={'target': target}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Convert timestamps to actual Time objects: - jsn['target']['inserted'] = Time(jsn['target']['inserted'], scale='utc') - if jsn['target']['discovery_date'] is not None: - jsn['target']['discovery_date'] = Time(jsn['target']['discovery_date'], scale='utc') - - if output in ('json', 'dict'): - return jsn - - dict_tables = {} - - tab = Table( - names=('targetid', 'target_name', 'target_status', 'ra', 'decl', 'redshift', 'redshift_error', 'discovery_mag', 'catalog_downloaded', 'pointing_model_created', 'inserted', 'discovery_date', 'project', 'host_galaxy', 'ztf_id', 'sntype'), - dtype=('int32', 'str', 'str', 'float64', 'float64', 'float32', 'float32', 'float32', 'bool', 'bool', 'object', 'object', 'str', 'str', 'str', 'str'), - rows=[jsn['target']]) - - tab['ra'].description = 'Right ascension' - tab['ra'].unit = u.deg - tab['decl'].description = 'Declination' - tab['decl'].unit = u.deg - dict_tables['target'] = tab - - for table_name in ('references', 'avoid'): - tab = Table( - names=('starid', 'ra', 'decl', 'pm_ra', 'pm_dec', 'gaia_mag', 'gaia_bp_mag', 'gaia_rp_mag', 'gaia_variability', 'B_mag', 'V_mag', 'H_mag', 'J_mag', 'K_mag', 'u_mag', 'g_mag', 'r_mag', 'i_mag', 'z_mag', 'distance'), - dtype=('int64', 'float64', 'float64', 'float32', 'float32', 'float32', 'float32', 'float32', 'int32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float64'), - rows=jsn[table_name]) - - tab['starid'].description = 'Unique identifier in REFCAT2 catalog' - tab['ra'].description = 'Right ascension' - tab['ra'].unit = u.deg - tab['decl'].description = 'Declination' - tab['decl'].unit = u.deg - tab['pm_ra'].description = 'Proper motion in right ascension' - tab['pm_ra'].unit = u.mas/u.yr - tab['pm_dec'].description = 'Proper motion in declination' - tab['pm_dec'].unit = u.mas/u.yr - tab['distance'].description = 'Distance from object to target' - tab['distance'].unit = u.deg - - tab['gaia_mag'].description = 'Gaia G magnitude' - tab['gaia_bp_mag'].description = 'Gaia Bp magnitude' - tab['gaia_rp_mag'].description = 'Gaia Rp magnitude' - tab['gaia_variability'].description = 'Gaia variability classification' - tab['B_mag'].description = 'Johnson B magnitude' - tab['V_mag'].description = 'Johnson V magnitude' - tab['H_mag'].description = '2MASS H magnitude' - tab['J_mag'].description = '2MASS J magnitude' - tab['K_mag'].description = '2MASS K magnitude' - tab['u_mag'].description = 'u magnitude' - tab['g_mag'].description = 'g magnitude' - tab['r_mag'].description = 'r magnitude' - tab['i_mag'].description = 'i magnitude' - tab['z_mag'].description = 'z magnitude' - - # Add some meta-data to the table as well: - tab.meta['targetid'] = int(dict_tables['target']['targetid']) - - dict_tables[table_name] = tab - - return dict_tables - -#-------------------------------------------------------------------------------------------------- + """ + + Parameters: + target (int or str): + radius (float, optional): Radius around target in degrees to return targets for. 
+ output (str, optional): Desired output format. Choices are 'table', 'dict', 'json'.
+ Default='table'.
+
+ Returns:
+ dict: Dictionary with three members:
+ - 'target': Information about target.
+ - 'references': Table with information about reference stars close to target.
+ - 'avoid': Table with stars close to target which should be avoided in FOV selection.
+
+ .. codeauthor:: Rasmus Handberg
+ """
+
+ assert output in ('table', 'json', 'dict'), "Invalid output format"
+
+ # Get API token from config file:
+ config = load_config()
+ token = config.get('api', 'token', fallback=None)
+ if token is None:
+ raise RuntimeError("No API token has been defined")
+
+ #
+ r = requests.get('https://flows.phys.au.dk/api/reference_stars.php', params={'target': target},
+ headers={'Authorization': 'Bearer ' + token})
+ r.raise_for_status()
+ jsn = r.json()
+
+ # Convert timestamps to actual Time objects:
+ jsn['target']['inserted'] = Time(jsn['target']['inserted'], scale='utc')
+ if jsn['target']['discovery_date'] is not None:
+ jsn['target']['discovery_date'] = Time(jsn['target']['discovery_date'], scale='utc')
+
+ if output in ('json', 'dict'):
+ return jsn
+
+ dict_tables = {}
+
+ tab = Table(names=(
+ 'targetid', 'target_name', 'target_status', 'ra', 'decl', 'redshift', 'redshift_error', 'discovery_mag',
+ 'catalog_downloaded', 'pointing_model_created', 'inserted', 'discovery_date', 'project', 'host_galaxy',
+ 'ztf_id', 'sntype'), dtype=(
+ 'int32', 'str', 'str', 'float64', 'float64', 'float32', 'float32', 'float32', 'bool', 'bool', 'object',
+ 'object', 'str', 'str', 'str', 'str'), rows=[jsn['target']])
+
+ tab['ra'].description = 'Right ascension'
+ tab['ra'].unit = u.deg
+ tab['decl'].description = 'Declination'
+ tab['decl'].unit = u.deg
+ dict_tables['target'] = tab
+
+ for table_name in ('references', 'avoid'):
+ tab = Table(names=(
+ 'starid', 'ra', 'decl', 'pm_ra', 'pm_dec', 'gaia_mag', 'gaia_bp_mag', 'gaia_rp_mag', 'gaia_variability',
+ 'B_mag', 'V_mag', 'H_mag', 'J_mag', 'K_mag', 'u_mag', 'g_mag', 'r_mag', 'i_mag', 'z_mag', 'distance'),
+ dtype=('int64', 'float64', 'float64', 'float32', 'float32', 'float32', 'float32', 'float32', 'int32',
+ 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32', 'float32',
+ 'float32', 'float64'), rows=jsn[table_name])
+
+ tab['starid'].description = 'Unique identifier in REFCAT2 catalog'
+ tab['ra'].description = 'Right ascension'
+ tab['ra'].unit = u.deg
+ tab['decl'].description = 'Declination'
+ tab['decl'].unit = u.deg
+ tab['pm_ra'].description = 'Proper motion in right ascension'
+ tab['pm_ra'].unit = u.mas / u.yr
+ tab['pm_dec'].description = 'Proper motion in declination'
+ tab['pm_dec'].unit = u.mas / u.yr
+ tab['distance'].description = 'Distance from object to target'
+ tab['distance'].unit = u.deg
+
+ tab['gaia_mag'].description = 'Gaia G magnitude'
+ tab['gaia_bp_mag'].description = 'Gaia Bp magnitude'
+ tab['gaia_rp_mag'].description = 'Gaia Rp magnitude'
+ tab['gaia_variability'].description = 'Gaia variability classification'
+ tab['B_mag'].description = 'Johnson B magnitude'
+ tab['V_mag'].description = 'Johnson V magnitude'
+ tab['H_mag'].description = '2MASS H magnitude'
+ tab['J_mag'].description = '2MASS J magnitude'
+ tab['K_mag'].description = '2MASS K magnitude'
+ tab['u_mag'].description = 'u magnitude'
+ tab['g_mag'].description = 'g magnitude'
+ tab['r_mag'].description = 'r magnitude'
+ tab['i_mag'].description = 'i magnitude'
+ tab['z_mag'].description = 'z magnitude'
+
+ # Add some
meta-data to the table as well: + tab.meta['targetid'] = int(dict_tables['target']['targetid']) + + dict_tables[table_name] = tab + + return dict_tables + + +# -------------------------------------------------------------------------------------------------- def get_catalog_missing(): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise Exception("No API token has been defined") - - # - r = requests.get('https://flows.phys.au.dk/api/catalog_missing.php', - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - return r.json() + # Get API token from config file: + config = load_config() + token = config.get('api', 'token', fallback=None) + if token is None: + raise Exception("No API token has been defined") + + # + r = requests.get('https://flows.phys.au.dk/api/catalog_missing.php', headers={'Authorization': 'Bearer ' + token}) + r.raise_for_status() + return r.json() diff --git a/flows/api/datafiles.py b/flows/api/datafiles.py index 5cf7e31..5e0c614 100644 --- a/flows/api/datafiles.py +++ b/flows/api/datafiles.py @@ -10,70 +10,69 @@ from functools import lru_cache from ..config import load_config -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- @lru_cache(maxsize=10) def get_datafile(fileid): + # Get API token from config file: + config = load_config() + token = config.get('api', 'token', fallback=None) + if token is None: + raise RuntimeError("No API token has been defined") - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") + r = requests.get('https://flows.phys.au.dk/api/datafiles.php', params={'fileid': fileid}, + headers={'Authorization': 'Bearer ' + token}) + r.raise_for_status() + jsn = r.json() - r = requests.get('https://flows.phys.au.dk/api/datafiles.php', - params={'fileid': fileid}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() + # Parse some of the fields to Python objects: + jsn['inserted'] = datetime.strptime(jsn['inserted'], '%Y-%m-%d %H:%M:%S.%f') + jsn['lastmodified'] = datetime.strptime(jsn['lastmodified'], '%Y-%m-%d %H:%M:%S.%f') - # Parse some of the fields to Python objects: - jsn['inserted'] = datetime.strptime(jsn['inserted'], '%Y-%m-%d %H:%M:%S.%f') - jsn['lastmodified'] = datetime.strptime(jsn['lastmodified'], '%Y-%m-%d %H:%M:%S.%f') + return jsn - return jsn -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def get_datafiles(targetid=None, filt=None, minversion=None): - """ - Get list of data file IDs to be processed. - - Parameters: - targetid (int, optional): Target ID to process. - filt (str, optional): Filter the returned list: - - ``missing``: Return only data files that have not yet been processed. - - ``'all'``: Return all data files. - minversion (str, optional): Special filter matching files not processed at least with - the specified version (defined internally in API for now). - - Returns: - list: List of data files the can be processed. - - .. 
codeauthor:: Rasmus Handberg
- """
-
- # Validate input:
- if filt is None:
- filt = 'missing'
- if filt not in ('missing', 'all', 'error'):
- raise ValueError("Invalid filter specified: '%s'" % filt)
-
- # Get API token from config file:
- config = load_config()
- token = config.get('api', 'token', fallback=None)
- if token is None:
- raise RuntimeError("No API token has been defined")
-
- params = {}
- if targetid is not None:
- params['targetid'] = targetid
- if minversion is not None:
- params['minversion'] = minversion
- params['filter'] = filt
-
- r = requests.get('https://flows.phys.au.dk/api/datafiles.php',
- params=params,
- headers={'Authorization': 'Bearer ' + token})
- r.raise_for_status()
- jsn = r.json()
-
- return jsn
+ """
+ Get list of data file IDs to be processed.
+
+ Parameters:
+ targetid (int, optional): Target ID to process.
+ filt (str, optional): Filter the returned list:
+ - ``missing``: Return only data files that have not yet been processed.
+ - ``'all'``: Return all data files.
+ minversion (str, optional): Special filter matching files not processed at least with
+ the specified version (defined internally in API for now).
+
+ Returns:
+ list: List of data files that can be processed.
+
+ .. codeauthor:: Rasmus Handberg
+ """
+
+ # Validate input:
+ if filt is None:
+ filt = 'missing'
+ if filt not in ('missing', 'all', 'error'):
+ raise ValueError("Invalid filter specified: '%s'" % filt)
+
+ # Get API token from config file:
+ config = load_config()
+ token = config.get('api', 'token', fallback=None)
+ if token is None:
+ raise RuntimeError("No API token has been defined")
+
+ params = {}
+ if targetid is not None:
+ params['targetid'] = targetid
+ if minversion is not None:
+ params['minversion'] = minversion
+ params['filter'] = filt
+
+ r = requests.get('https://flows.phys.au.dk/api/datafiles.php', params=params,
+ headers={'Authorization': 'Bearer ' + token})
+ r.raise_for_status()
+ jsn = r.json()
+
+ return jsn
diff --git a/flows/api/filters.py b/flows/api/filters.py index 05e690c..49afcef 100644
--- a/flows/api/filters.py
+++ b/flows/api/filters.py
@@ -10,26 +10,25 @@ import astropy.units as u
 from ..config import load_config
-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 @lru_cache(maxsize=10)
 def get_filters():
+ # Get API token from config file:
+ config = load_config()
+ token = config.get('api', 'token', fallback=None)
+ if token is None:
+ raise RuntimeError("No API token has been defined")
- # Get API token from config file:
- config = load_config()
- token = config.get('api', 'token', fallback=None)
- if token is None:
- raise RuntimeError("No API token has been defined")
-
- r = requests.get('https://flows.phys.au.dk/api/filters.php',
- headers={'Authorization': 'Bearer ' + token})
- r.raise_for_status()
- jsn = r.json()
+ r = requests.get('https://flows.phys.au.dk/api/filters.php', headers={'Authorization': 'Bearer ' + token})
+ r.raise_for_status()
+ jsn = r.json()
- # Add units:
- for f, val in jsn.items():
- if val.get('wavelength_center'):
- val['wavelength_center'] *= u.nm
- if val.get('wavelength_width'):
- val['wavelength_width'] *= u.nm
+ # Add units:
+ for f, val in jsn.items():
+ if val.get('wavelength_center'):
+ val['wavelength_center'] *= u.nm
+ if val.get('wavelength_width'):
+ val['wavelength_width'] *= u.nm
- return jsn
+ return jsn
diff --git a/flows/api/lightcurves.py
b/flows/api/lightcurves.py index a2eb22a..c89154e 100644
--- a/flows/api/lightcurves.py
+++ b/flows/api/lightcurves.py
@@ -12,43 +12,43 @@ from astropy.table import Table
 from ..config import load_config
-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 def get_lightcurve(target):
- """
- Retrieve lightcurve from Flows server.
+ """
+ Retrieve lightcurve from Flows server.
- Parameters:
- target (int): Target to download lightcurve for.
+ Parameters:
+ target (int): Target to download lightcurve for.
- Returns:
- :class:`astropy.table.Table`: Table containing lightcurve.
+ Returns:
+ :class:`astropy.table.Table`: Table containing lightcurve.
- TODO:
- - Enable caching of files.
+ TODO:
+ - Enable caching of files.
- .. codeauthor:: Rasmus Handberg
- """
+ .. codeauthor:: Rasmus Handberg
+ """
- # Get API token from config file:
- config = load_config()
- token = config.get('api', 'token', fallback=None)
- if token is None:
- raise RuntimeError("No API token has been defined")
+ # Get API token from config file:
+ config = load_config()
+ token = config.get('api', 'token', fallback=None)
+ if token is None:
+ raise RuntimeError("No API token has been defined")
- # Send query to the Flows API:
- params = {'target': target}
- r = requests.get('https://flows.phys.au.dk/api/lightcurve.php',
- params=params,
- headers={'Authorization': 'Bearer ' + token})
- r.raise_for_status()
+ # Send query to the Flows API:
+ params = {'target': target}
+ r = requests.get('https://flows.phys.au.dk/api/lightcurve.php', params=params,
+ headers={'Authorization': 'Bearer ' + token})
+ r.raise_for_status()
- # Create tempory directory and save the file into there,
- # then open the file as a Table:
- with tempfile.TemporaryDirectory() as tmpdir:
- tmpfile = os.path.join(tmpdir, 'table.ecsv')
- with open(tmpfile, 'w') as fid:
- fid.write(r.text)
+ # Create a temporary directory and save the file there,
+ # then open the file as a Table:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpfile = os.path.join(tmpdir, 'table.ecsv')
+ with open(tmpfile, 'w') as fid:
+ fid.write(r.text)
- tab = Table.read(tmpfile, format='ascii.ecsv')
+ tab = Table.read(tmpfile, format='ascii.ecsv')
- return tab
+ return tab
diff --git a/flows/api/photometry_api.py b/flows/api/photometry_api.py index fdddf08..89629b5 100644
--- a/flows/api/photometry_api.py
+++ b/flows/api/photometry_api.py
@@ -19,162 +19,159 @@ from ..config import load_config
 from ..utilities import get_filehash
-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 def get_photometry(photid):
- """
- Retrieve lightcurve from Flows server.
-
- Please note that it can significantly speed up repeated calls to this function
- to specify a cache directory in the config-file under api -> photometry_cache.
- This will download the files only once and store them in this local cache for
- use in subsequent calls.
-
- Parameters:
- photid (int): Fileid for the photometry file.
-
- Returns:
- :class:`astropy.table.Table`: Table containing photometry.
-
- ..
codeauthor:: Rasmus Handberg - """ - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # Determine where to store the downloaded file: - photcache = config.get('api', 'photometry_cache', fallback=None) - tmpdir = None - if photcache is not None: - photcache = os.path.abspath(photcache) - if not os.path.isdir(photcache): - raise FileNotFoundError(f"Photometry cache directory does not exist: {photcache}") - else: - tmpdir = tempfile.TemporaryDirectory(prefix='flows-api-get_photometry-') - photcache = tmpdir.name - - # Construct path to the photometry file in the cache: - photfile = os.path.join(photcache, f'photometry-{photid:d}.ecsv') - - if not os.path.isfile(photfile): - # Send query to the Flows API: - params = {'fileid': photid} - r = requests.get('https://flows.phys.au.dk/api/download_photometry.php', - params=params, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - - # Create tempory directory and save the file into there, - # then open the file as a Table: - with open(photfile, 'w') as fid: - fid.write(r.text) - - # Read the photometry file: - tab = Table.read(photfile, format='ascii.ecsv') - - # Explicitly cleanup the tempoary directory if it was created: - if tmpdir: - tmpdir.cleanup() - - return tab - -#-------------------------------------------------------------------------------------------------- + """ + Retrieve lightcurve from Flows server. + + Please note that it can significantly speed up repeated calls to this function + to specify a cache directory in the config-file under api -> photometry_cache. + This will download the files only once and store them in this local cache for + use in subsequent calls. + + Parameters: + photid (int): Fileid for the photometry file. + + Returns: + :class:`astropy.table.Table`: Table containing photometry. + + .. 
codeauthor:: Rasmus Handberg + """ + + # Get API token from config file: + config = load_config() + token = config.get('api', 'token', fallback=None) + if token is None: + raise RuntimeError("No API token has been defined") + + # Determine where to store the downloaded file: + photcache = config.get('api', 'photometry_cache', fallback=None) + tmpdir = None + if photcache is not None: + photcache = os.path.abspath(photcache) + if not os.path.isdir(photcache): + raise FileNotFoundError(f"Photometry cache directory does not exist: {photcache}") + else: + tmpdir = tempfile.TemporaryDirectory(prefix='flows-api-get_photometry-') + photcache = tmpdir.name + + # Construct path to the photometry file in the cache: + photfile = os.path.join(photcache, f'photometry-{photid:d}.ecsv') + + if not os.path.isfile(photfile): + # Send query to the Flows API: + params = {'fileid': photid} + r = requests.get('https://flows.phys.au.dk/api/download_photometry.php', params=params, + headers={'Authorization': 'Bearer ' + token}) + r.raise_for_status() + + # Create tempory directory and save the file into there, + # then open the file as a Table: + with open(photfile, 'w') as fid: + fid.write(r.text) + + # Read the photometry file: + tab = Table.read(photfile, format='ascii.ecsv') + + # Explicitly cleanup the tempoary directory if it was created: + if tmpdir: + tmpdir.cleanup() + + return tab + + +# -------------------------------------------------------------------------------------------------- def upload_photometry(fileid, delete_completed=False): - """ - Upload photometry results to Flows server. - - This will make the uploaded photometry the active/newest/best photometry and - be used in plots and shown on the website. - - Parameters: - fileid (int): File ID of photometry to upload to server. - delete_completed (bool, optional): Delete the photometry from the local - working directory if the upload was successful. Default=False. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - tqdm_settings = {'disable': None if logger.isEnabledFor(logging.INFO) else True} - - # Use API to get the datafile information: - datafile = api.get_datafile(fileid) - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - photdir_root = config.get('photometry', 'output', fallback='.') - - # Find the photometry output directory for this fileid: - photdir = os.path.join(photdir_root, datafile['target_name'], f'{fileid:05d}') - if not os.path.isdir(photdir): - # Do a last check, to ensure that we have not just added the wrong number of zeros - # to the directory name: - found_photdir = [] - for d in os.listdir(os.path.join(photdir_root, datafile['target_name'])): - if d.isdigit() and int(d) == fileid and os.path.isdir(d): - found_photdir.append(os.path.join(photdir_root, datafile['target_name'], d)) - # If we only found one, use it, otherwise throw an exception: - if len(found_photdir) == 1: - photdir = found_photdir[0] - elif len(found_photdir) > 1: - raise RuntimeError(f"Several photometry output found for fileid={fileid}. \ + """ + Upload photometry results to Flows server. + + This will make the uploaded photometry the active/newest/best photometry and + be used in plots and shown on the website. + + Parameters: + fileid (int): File ID of photometry to upload to server. 
+ delete_completed (bool, optional): Delete the photometry from the local
+ working directory if the upload was successful. Default=False.
+
+ .. codeauthor:: Rasmus Handberg
+ """
+
+ logger = logging.getLogger(__name__)
+ tqdm_settings = {'disable': None if logger.isEnabledFor(logging.INFO) else True}
+
+ # Use API to get the datafile information:
+ datafile = api.get_datafile(fileid)
+
+ # Get API token from config file:
+ config = load_config()
+ token = config.get('api', 'token', fallback=None)
+ if token is None:
+ raise RuntimeError("No API token has been defined")
+ photdir_root = config.get('photometry', 'output', fallback='.')
+
+ # Find the photometry output directory for this fileid:
+ photdir = os.path.join(photdir_root, datafile['target_name'], f'{fileid:05d}')
+ if not os.path.isdir(photdir):
+ # Do a last check, to ensure that we have not just added the wrong number of zeros
+ # to the directory name:
+ found_photdir = []
+ for d in os.listdir(os.path.join(photdir_root, datafile['target_name'])):
+ if d.isdigit() and int(d) == fileid and os.path.isdir(d):
+ found_photdir.append(os.path.join(photdir_root, datafile['target_name'], d))
+ # If we only found one, use it, otherwise throw an exception:
+ if len(found_photdir) == 1:
+ photdir = found_photdir[0]
+ elif len(found_photdir) > 1:
+ raise RuntimeError(f"Several photometry outputs found for fileid={fileid}. \ You need to do a cleanup of the photometry output directories.")
+ else:
+ raise FileNotFoundError(photdir)
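The block above resolves output directories named with five-digit zero-padded file IDs (f'{fileid:05d}'). A hypothetical call exercising this lookup, assuming a valid API token in config.ini and an existing photometry run under photometry->output; the file ID is made up and the import path simply follows this patch's file layout:

from flows.api.photometry_api import upload_photometry

# Looks for <output>/<target_name>/00123 (or any numeric directory equal to 123),
# zips photometry.ecsv, photometry.log and the PNG plots, and posts the archive:
upload_photometry(123, delete_completed=False)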
+ + # Make sure required files are actually there: + photdir = os.path.abspath(photdir) + files_existing = os.listdir(photdir) + if 'photometry.ecsv' not in files_existing: + raise FileNotFoundError(os.path.join(photdir, 'photometry.ecsv')) + if 'photometry.log' not in files_existing: + raise FileNotFoundError(os.path.join(photdir, 'photometry.log')) + + # Create list of files to be uploaded: + files = [os.path.join(photdir, 'photometry.ecsv'), os.path.join(photdir, 'photometry.log')] + files += glob.glob(os.path.join(photdir, '*.png')) + + # Create ZIP file: + with tempfile.TemporaryDirectory(prefix='flows-upload-') as tmpdir: + # Create ZIP-file within the temp directory: + fpath_zip = os.path.join(tmpdir, f'{fileid:05d}.zip') + + # Create ZIP file with all the files: + with zipfile.ZipFile(fpath_zip, 'w', allowZip64=True) as z: + for f in tqdm(files, desc=f'Zipping {fileid:d}', **tqdm_settings): + logger.debug('Zipping %s', f) + z.write(f, os.path.basename(f)) + + # Change the name of the uploaded file to contain the file hash: + fhash = get_filehash(fpath_zip) + fname_zip = f'{fileid:05d}-{fhash:s}.zip' + + # Send file to the API: + logger.info("Uploading to server...") + with open(fpath_zip, 'rb') as fid: + r = requests.post('https://flows.phys.au.dk/api/upload_photometry.php', params={'fileid': fileid}, + files={'file': (fname_zip, fid, 'application/zip')}, + headers={'Authorization': 'Bearer ' + token}) + + # Check the returned data from the API: + if r.text.strip() != 'OK': + logger.error(r.text) + raise RuntimeError("An error occurred while uploading photometry: " + r.text) + r.raise_for_status() + + # If we have made it this far, the upload must have been a success: + if delete_completed: + if set([os.path.basename(f) for f in files]) == set(os.listdir(photdir)): + logger.info("Deleting photometry from workdir: '%s'", photdir) + shutil.rmtree(photdir, ignore_errors=True) + else: + logger.warning("Not deleting photometry from workdir: '%s'", photdir) diff --git a/flows/api/set_photometry_status.py b/flows/api/set_photometry_status.py index b0fc572..2f6dcb2 100644 --- a/flows/api/set_photometry_status.py +++ b/flows/api/set_photometry_status.py @@ -9,78 +9,79 @@ import requests from ..config import load_config -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def set_photometry_status(fileid, status): - """ - Set photometry status. - - Parameters: - fileid (int): - status (str): Choises are 'running', 'error' or 'done'. - - .. 
codeauthor:: Rasmus Handberg
- """
- # Validate the input:
- logger = logging.getLogger(__name__)
- if status not in ('running', 'error', 'abort', 'ingest', 'done'):
- raise ValueError('Invalid status')
-
- # Get API token from config file:
- config = load_config()
- i_am_pipeline = config.getboolean('api', 'pipeline', fallback=False)
- if not i_am_pipeline:
- logger.debug("Not setting status since user is not pipeline")
- return False
-
- # Get API token from config file:
- token = config.get('api', 'token', fallback=None)
- if token is None:
- raise RuntimeError("No API token has been defined")
-
- # Send HTTP request to FLOWS server:
- r = requests.get('https://flows.phys.au.dk/api/set_photometry_status.php',
- params={'fileid': fileid, 'status': status},
- headers={'Authorization': 'Bearer ' + token})
- r.raise_for_status()
- res = r.text.strip()
-
- if res != 'OK':
- raise RuntimeError(res)
-
- return True
-
-#--------------------------------------------------------------------------------------------------
+ """
+ Set photometry status.
+
+ Parameters:
+ fileid (int):
+ status (str): Choices are 'running', 'error' or 'done'.
+
+ .. codeauthor:: Rasmus Handberg
+ """
+ # Validate the input:
+ logger = logging.getLogger(__name__)
+ if status not in ('running', 'error', 'abort', 'ingest', 'done'):
+ raise ValueError('Invalid status')
+
+ # Get API token from config file:
+ config = load_config()
+ i_am_pipeline = config.getboolean('api', 'pipeline', fallback=False)
+ if not i_am_pipeline:
+ logger.debug("Not setting status since user is not pipeline")
+ return False
+
+ # Get API token from config file:
+ token = config.get('api', 'token', fallback=None)
+ if token is None:
+ raise RuntimeError("No API token has been defined")
+
+ # Send HTTP request to FLOWS server:
+ r = requests.get('https://flows.phys.au.dk/api/set_photometry_status.php',
+ params={'fileid': fileid, 'status': status}, headers={'Authorization': 'Bearer ' + token})
+ r.raise_for_status()
+ res = r.text.strip()
+
+ if res != 'OK':
+ raise RuntimeError(res)
+
+ return True
+
+
+# --------------------------------------------------------------------------------------------------
 def cleanup_photometry_status():
- """
- Perform a cleanup of the photometry status indicator.
-
- This will change all processes still marked as "running"
- to "abort" if they have been running for more than a day.
-
- .. codeauthor:: Rasmus Handberg
- """
- # Validate the input:
- logger = logging.getLogger(__name__)
-
- # Get API token from config file:
- config = load_config()
- i_am_pipeline = config.getboolean('api', 'pipeline', fallback=False)
- if not i_am_pipeline:
- logger.debug("Not setting status since user is not pipeline")
- return False
-
- # Get API token from config file:
- token = config.get('api', 'token', fallback=None)
- if token is None:
- raise RuntimeError("No API token has been defined")
-
- # Send HTTP request to FLOWS server:
- r = requests.get('https://flows.phys.au.dk/api/cleanup_photometry_status.php',
- headers={'Authorization': 'Bearer ' + token})
- r.raise_for_status()
- res = r.text.strip()
-
- if res != 'OK':
- raise RuntimeError(res)
-
- return True
+ """
+ Perform a cleanup of the photometry status indicator.
+
+ This will change all processes still marked as "running"
+ to "abort" if they have been running for more than a day.
+
+ ..
codeauthor:: Rasmus Handberg + """ + # Validate the input: + logger = logging.getLogger(__name__) + + # Get API token from config file: + config = load_config() + i_am_pipeline = config.getboolean('api', 'pipeline', fallback=False) + if not i_am_pipeline: + logger.debug("Not setting status since user is not pipeline") + return False + + # Get API token from config file: + token = config.get('api', 'token', fallback=None) + if token is None: + raise RuntimeError("No API token has been defined") + + # Send HTTP request to FLOWS server: + r = requests.get('https://flows.phys.au.dk/api/cleanup_photometry_status.php', + headers={'Authorization': 'Bearer ' + token}) + r.raise_for_status() + res = r.text.strip() + + if res != 'OK': + raise RuntimeError(res) + + return True diff --git a/flows/api/sites.py b/flows/api/sites.py index 053310d..f18e955 100644 --- a/flows/api/sites.py +++ b/flows/api/sites.py @@ -11,44 +11,44 @@ from astropy.coordinates import EarthLocation from ..config import load_config -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- @lru_cache(maxsize=10) def get_site(siteid): + # Get API token from config file: + config = load_config() + token = config.get('api', 'token', fallback=None) + if token is None: + raise RuntimeError("No API token has been defined") - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") + r = requests.get('https://flows.phys.au.dk/api/sites.php', params={'siteid': siteid}, + headers={'Authorization': 'Bearer ' + token}) + r.raise_for_status() + jsn = r.json() - r = requests.get('https://flows.phys.au.dk/api/sites.php', - params={'siteid': siteid}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() + # Special derived objects: + jsn['EarthLocation'] = EarthLocation(lat=jsn['latitude'] * u.deg, lon=jsn['longitude'] * u.deg, + height=jsn['elevation'] * u.m) - # Special derived objects: - jsn['EarthLocation'] = EarthLocation(lat=jsn['latitude']*u.deg, lon=jsn['longitude']*u.deg, height=jsn['elevation']*u.m) + return jsn - return jsn -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- @lru_cache(maxsize=1) def get_all_sites(): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - r = requests.get('https://flows.phys.au.dk/api/sites.php', - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Special derived objects: - for site in jsn: - site['EarthLocation'] = EarthLocation(lat=site['latitude']*u.deg, lon=site['longitude']*u.deg, height=site['elevation']*u.m) - - return jsn + # Get API token from config file: + config = load_config() + token = config.get('api', 'token', fallback=None) + if token is None: + raise RuntimeError("No API token has been defined") + + r = requests.get('https://flows.phys.au.dk/api/sites.php', headers={'Authorization': 'Bearer ' + token}) + r.raise_for_status() + jsn = r.json() + + # Special derived objects: + for site in jsn: + site['EarthLocation'] = EarthLocation(lat=site['latitude'] * 
u.deg, lon=site['longitude'] * u.deg, + height=site['elevation'] * u.m) + + return jsn diff --git a/flows/api/targets.py b/flows/api/targets.py index 4d04216..8dd58fb 100644 --- a/flows/api/targets.py +++ b/flows/api/targets.py @@ -14,149 +14,135 @@ from functools import lru_cache from ..config import load_config -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- @lru_cache(maxsize=10) def get_target(target): + # Get API token from config file: + config = load_config() + token = config.get('api', 'token', fallback=None) + if token is None: + raise RuntimeError("No API token has been defined") - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") + r = requests.get('https://flows.phys.au.dk/api/targets.php', params={'target': target}, + headers={'Authorization': 'Bearer ' + token}) + r.raise_for_status() + jsn = r.json() - r = requests.get('https://flows.phys.au.dk/api/targets.php', - params={'target': target}, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() + # Parse some of the fields to Python objects: + jsn['inserted'] = datetime.strptime(jsn['inserted'], '%Y-%m-%d %H:%M:%S.%f') + if jsn['discovery_date']: + jsn['discovery_date'] = Time(jsn['discovery_date'], format='iso', scale='utc') - # Parse some of the fields to Python objects: - jsn['inserted'] = datetime.strptime(jsn['inserted'], '%Y-%m-%d %H:%M:%S.%f') - if jsn['discovery_date']: - jsn['discovery_date'] = Time(jsn['discovery_date'], format='iso', scale='utc') + return jsn - return jsn -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- @lru_cache(maxsize=1) def get_targets(): - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - r = requests.get('https://flows.phys.au.dk/api/targets.php', - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Parse some of the fields to Python objects: - for tgt in jsn: - tgt['inserted'] = datetime.strptime(tgt['inserted'], '%Y-%m-%d %H:%M:%S.%f') - if tgt['discovery_date']: - tgt['discovery_date'] = Time(tgt['discovery_date'], format='iso', scale='utc') - - return jsn - -#-------------------------------------------------------------------------------------------------- -def add_target(name, coord, redshift=None, redshift_error=None, discovery_date=None, - discovery_mag=None, host_galaxy=None, ztf=None, sntype=None, status='candidate', - project='flows'): - """ - Add new candidate or target. 
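The add_target docstring, which continues below in both its old and new form, lists several accepted formats for coordinates and discovery dates. A minimal hypothetical call combining them (the target name and values are invented, a valid API token in config.ini is assumed, and the import path simply follows this patch's file layout):

from astropy.coordinates import SkyCoord
from astropy.time import Time
from flows.api.targets import add_target

coord = SkyCoord(ra=19.1, dec=89.00001, unit='deg', frame='icrs')
discovery_date = Time('2020-01-02 00:00:00', format='iso', scale='utc')
# '2020xyz' matches the required "YYYYxyz" naming pattern:
targetid = add_target('2020xyz', coord, discovery_date=discovery_date,
                      discovery_mag=18.5, status='candidate')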
- - Coordinates are specified using an Astropy SkyCoord object, which can be - created in the following way: - - coord = SkyCoord(ra=19.1, dec=89.00001, unit='deg', frame='icrs') - - The easiest way is to specify ``discovery_date`` as an Astropy Time object: - - discovery_date = Time('2020-01-02 00:00:00', format='iso', scale='utc') - - Alternatively, you can also specify it as a :class:`datetime.datetime` object, - but some care has to be taken with specifying the correct timezone: - - discovery_date = datetime.strptime('2020-01-02 00:00:00', '%Y-%m-%d %H:%M:%S%f') - discovery_date = pytz.timezone('America/New_York').localize(ddate) - - Lastly, it can be given as a simple date-string of the following form, - but here the data has to be given in UTC: - - discovery_date = '2020-01-02 23:56:02.123' - - Parameters: - name (str): Name of target. Must be of the form "YYYYxyz", where YYYY is the year, - and xyz are letters. - coord (:class:ʼastropy.coordinates.SkyCoordʼ): Sky coordinates of target. - redshift (float, optional): Redshift. - redshift_error (float, optional): Uncertainty on redshift. - discovery_date (:class:`astropy.time.Time`, :class:`datetime.datetime` or str, optional): - discovery_mag (float, optional): Magnitude at time of discovery. - host_galaxy (str, optional): Host galaxy name. - sntype (str, optional): Supernovae type (e.g. Ia, Ib, II). - ztf (str, optional): ZTF identifier. - status (str, optional): - project (str, optional): - - Returns: - int: New target identifier in Flows system. - - .. codeauthor:: Rasmus Handberg - """ - # Check and convert input: - if not re.match(r'^[12]\d{3}([A-Z]|[a-z]{2,4})$', name.strip()): - raise ValueError("Invalid target name.") - - if redshift is None and redshift_error is not None: - raise ValueError("Redshift error specified without redshift value") - - if isinstance(discovery_date, Time): - discovery_date = discovery_date.utc.iso - elif isinstance(discovery_date, datetime): - discovery_date = discovery_date.astimezone(pytz.timezone('UTC')) - discovery_date = discovery_date.strftime('%Y-%m-%d %H:%M:%S%f') - elif isinstance(discovery_date, str): - discovery_date = datetime.strptime(discovery_date, '%Y-%m-%d %H:%M:%S%f') - discovery_date = discovery_date.strftime('%Y-%m-%d %H:%M:%S%f') - - if status not in ('candidate', 'target'): - raise ValueError("Invalid target status.") - - # Get API token from config file: - config = load_config() - token = config.get('api', 'token', fallback=None) - if token is None: - raise RuntimeError("No API token has been defined") - - # Gather parameters to be sent to API: - params = { - 'targetid': 0, - 'target_name': name.strip(), - 'ra': coord.icrs.ra.deg, - 'decl': coord.icrs.dec.deg, - 'redshift': redshift, - 'redshift_error': redshift_error, - 'discovery_date': discovery_date, - 'discovery_mag': discovery_mag, - 'host_galaxy': host_galaxy, - 'project': project, - 'ztf_id': ztf, - 'target_status': status, - 'sntype': sntype - } - - # Post the request to the API: - r = requests.post('https://flows.phys.au.dk/api/targets_add.php', - data=params, - headers={'Authorization': 'Bearer ' + token}) - r.raise_for_status() - jsn = r.json() - - # Check for errors: - if jsn['errors'] is not None: - raise RuntimeError(f"Adding target '{name}' resulted in an error: {jsn['errors']}") - - return int(jsn['targetid']) + # Get API token from config file: + config = load_config() + token = config.get('api', 'token', fallback=None) + if token is None: + raise RuntimeError("No API token has been defined") + + r = 
requests.get('https://flows.phys.au.dk/api/targets.php', headers={'Authorization': 'Bearer ' + token})
+ r.raise_for_status()
+ jsn = r.json()
+
+ # Parse some of the fields to Python objects:
+ for tgt in jsn:
+ tgt['inserted'] = datetime.strptime(tgt['inserted'], '%Y-%m-%d %H:%M:%S.%f')
+ if tgt['discovery_date']:
+ tgt['discovery_date'] = Time(tgt['discovery_date'], format='iso', scale='utc')
+
+ return jsn
+
+
+# --------------------------------------------------------------------------------------------------
 def add_target(name, coord, redshift=None, redshift_error=None, discovery_date=None, discovery_mag=None,
+ host_galaxy=None, ztf=None, sntype=None, status='candidate', project='flows'):
+ """
+ Add new candidate or target.
+
+ Coordinates are specified using an Astropy SkyCoord object, which can be
+ created in the following way:
+
+ coord = SkyCoord(ra=19.1, dec=89.00001, unit='deg', frame='icrs')
+
+ The easiest way is to specify ``discovery_date`` as an Astropy Time object:
+
+ discovery_date = Time('2020-01-02 00:00:00', format='iso', scale='utc')
+
+ Alternatively, you can also specify it as a :class:`datetime.datetime` object,
+ but some care has to be taken with specifying the correct timezone:
+
+ discovery_date = datetime.strptime('2020-01-02 00:00:00', '%Y-%m-%d %H:%M:%S%f')
+ discovery_date = pytz.timezone('America/New_York').localize(discovery_date)
+
+ Lastly, it can be given as a simple date-string of the following form,
+ but here the date has to be given in UTC:
+
+ discovery_date = '2020-01-02 23:56:02.123'
+
+ Parameters:
+ name (str): Name of target. Must be of the form "YYYYxyz", where YYYY is the year,
+ and xyz are letters.
+ coord (:class:`astropy.coordinates.SkyCoord`): Sky coordinates of target.
+ redshift (float, optional): Redshift.
+ redshift_error (float, optional): Uncertainty on redshift.
+ discovery_date (:class:`astropy.time.Time`, :class:`datetime.datetime` or str, optional):
+ discovery_mag (float, optional): Magnitude at time of discovery.
+ host_galaxy (str, optional): Host galaxy name.
+ sntype (str, optional): Supernova type (e.g. Ia, Ib, II).
+ ztf (str, optional): ZTF identifier.
+ status (str, optional):
+ project (str, optional):
+
+ Returns:
+ int: New target identifier in Flows system.
+
+ ..
codeauthor:: Rasmus Handberg + """ + # Check and convert input: + if not re.match(r'^[12]\d{3}([A-Z]|[a-z]{2,4})$', name.strip()): + raise ValueError("Invalid target name.") + + if redshift is None and redshift_error is not None: + raise ValueError("Redshift error specified without redshift value") + + if isinstance(discovery_date, Time): + discovery_date = discovery_date.utc.iso + elif isinstance(discovery_date, datetime): + discovery_date = discovery_date.astimezone(pytz.timezone('UTC')) + discovery_date = discovery_date.strftime('%Y-%m-%d %H:%M:%S%f') + elif isinstance(discovery_date, str): + discovery_date = datetime.strptime(discovery_date, '%Y-%m-%d %H:%M:%S%f') + discovery_date = discovery_date.strftime('%Y-%m-%d %H:%M:%S%f') + + if status not in ('candidate', 'target'): + raise ValueError("Invalid target status.") + + # Get API token from config file: + config = load_config() + token = config.get('api', 'token', fallback=None) + if token is None: + raise RuntimeError("No API token has been defined") + + # Gather parameters to be sent to API: + params = {'targetid': 0, 'target_name': name.strip(), 'ra': coord.icrs.ra.deg, 'decl': coord.icrs.dec.deg, + 'redshift': redshift, 'redshift_error': redshift_error, 'discovery_date': discovery_date, + 'discovery_mag': discovery_mag, 'host_galaxy': host_galaxy, 'project': project, 'ztf_id': ztf, + 'target_status': status, 'sntype': sntype} + + # Post the request to the API: + r = requests.post('https://flows.phys.au.dk/api/targets_add.php', data=params, + headers={'Authorization': 'Bearer ' + token}) + r.raise_for_status() + jsn = r.json() + + # Check for errors: + if jsn['errors'] is not None: + raise RuntimeError(f"Adding target '{name}' resulted in an error: {jsn['errors']}") + + return int(jsn['targetid']) diff --git a/flows/catalogs.py b/flows/catalogs.py index 264fc23..2d7ca98 100644 --- a/flows/catalogs.py +++ b/flows/catalogs.py @@ -24,613 +24,584 @@ from .aadc_db import AADC_DB from .ztf import query_ztf_id -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class CasjobsError(RuntimeError): - pass + pass + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- class CasjobsMemoryError(RuntimeError): - pass + pass + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def floatval(value): - return None if value == '' or value == 'NA' or value == '0' else float(value) + return None if value == '' or value == 'NA' or value == '0' else float(value) -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def intval(value): - return None if value == '' else int(value) + return None if value == '' else int(value) + -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def configure_casjobs(overwrite=False): - """ - Set up CasJobs if needed. 
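configure_casjobs, whose docstring continues below, writes the CasJobs.config consumed by the REFCAT2 cone-search helpers later in this file. A hypothetical end-to-end query, assuming casjobs->wsid and casjobs->password are set in config.ini and a Java runtime is available for the casjobs.jar client (the coordinates are made up):

import astropy.units as u
from astropy.coordinates import SkyCoord
from flows.catalogs import query_casjobs_refcat2

coo_centre = SkyCoord(ra=256.8, dec=30.2, unit='deg', frame='icrs')
refcat = query_casjobs_refcat2(coo_centre, radius=24*u.arcmin)
print(len(refcat), 'unique REFCAT2 sources within the search cone')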
- - Parameters: - overwrite (bool, optional): Overwrite existing configuration. Default (False) is to not - overwrite existing configuration. - - .. codeauthor:: Rasmus Handberg - """ - - __dir__ = os.path.dirname(os.path.realpath(__file__)) - casjobs_config = os.path.join(__dir__, 'casjobs', 'CasJobs.config') - if os.path.isfile(casjobs_config) and not overwrite: - return - - config = load_config() - wsid = config.get('casjobs', 'wsid', fallback=None) - passwd = config.get('casjobs', 'password', fallback=None) - if wsid is None or passwd is None: - raise CasjobsError("CasJobs WSID and PASSWORD not in config.ini") - - try: - with open(casjobs_config, 'w') as fid: - fid.write("wsid={0:s}\n".format(wsid)) - fid.write("password={0:s}\n".format(passwd)) - fid.write("default_target=HLSP_ATLAS_REFCAT2\n") - fid.write("default_queue=1\n") - fid.write("default_days=1\n") - fid.write("verbose=false\n") - fid.write("debug=false\n") - fid.write("jobs_location=http://mastweb.stsci.edu/gcasjobs/services/jobs.asmx\n") - except: # noqa: E722, pragma: no cover - if os.path.isfile(casjobs_config): - os.remove(casjobs_config) - -#-------------------------------------------------------------------------------------------------- -def query_casjobs_refcat2(coo_centre, radius=24*u.arcmin): - """ - Uses the CasJobs program to do a cone-search around the position. - - Will first attempt to do single large cone-search, and if that - fails because of CasJobs memory limits, will sub-divide the cone - into smaller queries. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (Angle, optional): Search radius. Default is 24 arcmin. - - Returns: - list: List of dicts with the REFCAT2 information. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - if isinstance(radius, (float, int)): - radius *= u.deg - - try: - results = _query_casjobs_refcat2(coo_centre, radius=radius) - except CasjobsMemoryError: - logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") - results = _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius=radius) - - # Remove duplicate entries: - _, indx = np.unique([res['starid'] for res in results], return_index=True) - results = [results[k] for k in indx] - - # Trim away anything outside radius: - ra = [res['ra'] for res in results] - decl = [res['decl'] for res in results] - coords = SkyCoord(ra=ra, dec=decl, unit='deg', frame='icrs') - sep = coords.separation(coo_centre) - results = [res for k,res in enumerate(results) if sep[k] <= radius] - - logger.debug("Found %d unique results", len(results)) - return results - -#-------------------------------------------------------------------------------------------------- + """ + Set up CasJobs if needed. + + Parameters: + overwrite (bool, optional): Overwrite existing configuration. Default (False) is to not + overwrite existing configuration. + + .. 
codeauthor:: Rasmus Handberg + """ + + __dir__ = os.path.dirname(os.path.realpath(__file__)) + casjobs_config = os.path.join(__dir__, 'casjobs', 'CasJobs.config') + if os.path.isfile(casjobs_config) and not overwrite: + return + + config = load_config() + wsid = config.get('casjobs', 'wsid', fallback=None) + passwd = config.get('casjobs', 'password', fallback=None) + if wsid is None or passwd is None: + raise CasjobsError("CasJobs WSID and PASSWORD not in config.ini") + + try: + with open(casjobs_config, 'w') as fid: + fid.write("wsid={0:s}\n".format(wsid)) + fid.write("password={0:s}\n".format(passwd)) + fid.write("default_target=HLSP_ATLAS_REFCAT2\n") + fid.write("default_queue=1\n") + fid.write("default_days=1\n") + fid.write("verbose=false\n") + fid.write("debug=false\n") + fid.write("jobs_location=http://mastweb.stsci.edu/gcasjobs/services/jobs.asmx\n") + except: # noqa: E722, pragma: no cover + if os.path.isfile(casjobs_config): + os.remove(casjobs_config) + + +# -------------------------------------------------------------------------------------------------- +def query_casjobs_refcat2(coo_centre, radius=24 * u.arcmin): + """ + Uses the CasJobs program to do a cone-search around the position. + + Will first attempt to do single large cone-search, and if that + fails because of CasJobs memory limits, will sub-divide the cone + into smaller queries. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (Angle, optional): Search radius. Default is 24 arcmin. + + Returns: + list: List of dicts with the REFCAT2 information. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + if isinstance(radius, (float, int)): + radius *= u.deg + + try: + results = _query_casjobs_refcat2(coo_centre, radius=radius) + except CasjobsMemoryError: + logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") + results = _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius=radius) + + # Remove duplicate entries: + _, indx = np.unique([res['starid'] for res in results], return_index=True) + results = [results[k] for k in indx] + + # Trim away anything outside radius: + ra = [res['ra'] for res in results] + decl = [res['decl'] for res in results] + coords = SkyCoord(ra=ra, dec=decl, unit='deg', frame='icrs') + sep = coords.separation(coo_centre) + results = [res for k, res in enumerate(results) if sep[k] <= radius] + + logger.debug("Found %d unique results", len(results)) + return results + + +# -------------------------------------------------------------------------------------------------- def _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius): - logger = logging.getLogger(__name__) - - # Just put in a stop criterion to avoid infinite recursion: - if radius < 0.04*u.deg: - raise Exception("Too many subdivides") - - # Search central cone: - try: - results = _query_casjobs_refcat2(coo_centre, radius=0.5*radius) - except CasjobsMemoryError: - logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") - results = _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius=0.5*radius) - - # Search six cones around central cone: - for n in range(6): - # FIXME: The 0.8 here is kind of a guess. 
There should be an analytic solution - new = SkyCoord( - ra=coo_centre.ra.deg + 0.8 * Angle(radius).deg * np.cos(n*60*np.pi/180), - dec=coo_centre.dec.deg + 0.8 * Angle(radius).deg * np.sin(n*60*np.pi/180), - unit='deg', frame='icrs') - - try: - results += _query_casjobs_refcat2(new, radius=0.5*radius) - except CasjobsMemoryError: - logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") - results += _query_casjobs_refcat2_divide_and_conquer(new, radius=0.5*radius) - - return results - -#-------------------------------------------------------------------------------------------------- -def _query_casjobs_refcat2(coo_centre, radius=24*u.arcmin): - """ - Uses the CasJobs program to do a cone-search around the position. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (Angle, optional): Search radius. Default is 24 arcmin. - - Returns: - list: List of dicts with the REFCAT2 information. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - if isinstance(radius, (float, int)): - radius *= u.deg - - sql = "SELECT r.* FROM fGetNearbyObjEq({ra:f}, {dec:f}, {radius:f}) AS n INNER JOIN HLSP_ATLAS_REFCAT2.refcat2 AS r ON n.objid=r.objid ORDER BY n.distance;".format( - ra=coo_centre.ra.deg, - dec=coo_centre.dec.deg, - radius=Angle(radius).deg - ) - logger.debug(sql) - - # Make sure that CasJobs have been configured: - configure_casjobs() - - # The command to run the casjobs script: - # BEWARE: This may change in the future without warning - it has before! - cmd = 'java -jar casjobs.jar execute "{0:s}"'.format(sql) - - # Execute the command: - cmd = shlex.split(cmd) - directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'casjobs') - proc = subprocess.Popen(cmd, cwd=directory, stdout=subprocess.PIPE, universal_newlines=True) - stdout, stderr = proc.communicate() - output = stdout.split("\n") - - # build list of all kois from output from the CasJobs-script: - error_thrown = False - results = [] - for line in output: - line = line.strip() - if line == '': - continue - if 'ERROR' in line: - error_thrown = True - break - - row = line.split(',') - if len(row) == 45 and row[0] != '[objid]:Integer': - results.append({ - 'starid': int(row[0]), - 'ra': floatval(row[1]), - 'decl': floatval(row[2]), - 'pm_ra': floatval(row[5]), - 'pm_dec': floatval(row[7]), - 'gaia_mag': floatval(row[9]), - 'gaia_bp_mag': floatval(row[11]), - 'gaia_rp_mag': floatval(row[13]), - 'gaia_variability': intval(row[17]), - 'g_mag': floatval(row[22]), - 'r_mag': floatval(row[26]), - 'i_mag': floatval(row[30]), - 'z_mag': floatval(row[34]), - 'J_mag': floatval(row[39]), - 'H_mag': floatval(row[41]), - 'K_mag': floatval(row[43]), - }) - - if error_thrown: - error_msg = '' - for line in output: - if len(line.strip()) > 0: - error_msg += line.strip() + "\n" - - logger.debug("Error Msg: %s", error_msg) - if 'query results exceed memory limit' in error_msg.lower(): - raise CasjobsMemoryError("Query results exceed memory limit") - else: - raise CasjobsError("ERROR detected in CasJobs: " + error_msg) - - if not results: - raise CasjobsError("Could not find anything on CasJobs") - - logger.debug("Found %d results", len(results)) - return results - -#-------------------------------------------------------------------------------------------------- -def query_apass(coo_centre, radius=24*u.arcmin): - """ - Queries APASS catalog using cone-search around the position. 
- - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float, optional): - - Returns: - list: List of dicts with the APASS information. - - .. codeauthor:: Rasmus Handberg - """ - - # https://vizier.u-strasbg.fr/viz-bin/VizieR-3?-source=II/336 - - if isinstance(radius, (float, int)): - radius *= u.deg - - data = { - 'ra': coo_centre.icrs.ra.deg, - 'dec': coo_centre.icrs.dec.deg, - 'radius': Angle(radius).deg, - 'outtype': '1' - } - - res = requests.post('https://www.aavso.org/cgi-bin/apass_dr10_download.pl', data=data) - res.raise_for_status() - - results = [] - - lines = res.text.split("\n") - #header = lines[0] - - for line in lines[1:]: - if line.strip() == '': continue - row = line.strip().split(',') - - results.append({ - 'ra': floatval(row[0]), - 'decl': floatval(row[2]), - 'V_mag': floatval(row[4]), - 'B_mag': floatval(row[7]), - 'u_mag': floatval(row[10]), - 'g_mag': floatval(row[13]), - 'r_mag': floatval(row[16]), - 'i_mag': floatval(row[19]), - 'z_mag': floatval(row[22]), - 'Y_mag': floatval(row[25]) - }) - - return results - -#-------------------------------------------------------------------------------------------------- -def query_sdss(coo_centre, radius=24*u.arcmin, dr=16, clean=True): - """ - Queries SDSS catalog using cone-search around the position. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float, optional): - dr (int, optional): SDSS Data Release to query. Default=16. - clean (bool, optional): Clean results for stars only and no other problems. - - Returns: - tuple: - - :class:`astropy.table.Table`: Table with SDSS information. - - :class:`astropy.coordinates.SkyCoord`: Sky coordinates for SDSS objects. - - .. codeauthor:: Emir Karamehmetoglu - .. codeauthor:: Rasmus Handberg - """ - - if isinstance(radius, (float, int)): - radius *= u.deg - - AT_sdss = SDSS.query_region(coo_centre, - photoobj_fields=['type', 'clean', 'ra', 'dec', 'psfMag_u'], - data_release=dr, - timeout=600, - radius=radius) - - if AT_sdss is None: - return None, None - - if clean: - # Clean SDSS following https://www.sdss.org/dr12/algorithms/photo_flags_recommend/ - # 6 == star, clean means remove interp, edge, suspicious defects, deblending problems, duplicates. - AT_sdss = AT_sdss[(AT_sdss['type'] == 6) & (AT_sdss['clean'] == 1)] - - # Remove these columns since they are no longer needed: - AT_sdss.remove_columns(['type', 'clean']) - - if len(AT_sdss) == 0: - return None, None - - # Create SkyCoord object with the coordinates: - sdss = SkyCoord( - ra=AT_sdss['ra'], - dec=AT_sdss['dec'], - unit=u.deg, - frame='icrs') - - return AT_sdss, sdss - -#-------------------------------------------------------------------------------------------------- -def query_simbad(coo_centre, radius=24*u.arcmin): - """ - Query SIMBAD using cone-search around the position using astroquery. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float, optional): - - Returns: - list: Astropy Table with SIMBAD information. - - .. 
codeauthor:: Rasmus Handberg - """ - - s = Simbad() - s.ROW_LIMIT = 0 - s.remove_votable_fields('coordinates') - s.add_votable_fields('ra(d;A;ICRS;J2000)', 'dec(d;D;ICRS;2000)', 'pmra', 'pmdec') - s.add_votable_fields('otype') - s.add_votable_fields('flux(B)', 'flux(V)', 'flux(R)', 'flux(I)', 'flux(J)', 'flux(H)', 'flux(K)') - s.add_votable_fields('flux(u)', 'flux(g)', 'flux(r)', 'flux(i)', 'flux(z)') - - rad = Angle(radius).arcmin - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning) - results = s.query_criteria(f'region(circle, icrs, {coo_centre.icrs.ra.deg:.10f} {coo_centre.icrs.dec.deg:+.10f}, {rad}m)', otypes='Star') - - if not results: - return None, None - - # Rename columns: - results.rename_column('MAIN_ID', 'main_id') - results.rename_column('RA_d_A_ICRS_J2000', 'ra') - results.rename_column('DEC_d_D_ICRS_2000', 'dec') - results.rename_column('PMRA', 'pmra') - results.rename_column('PMDEC', 'pmdec') - results.rename_column('FLUX_B', 'B_mag') - results.rename_column('FLUX_V', 'V_mag') - results.rename_column('FLUX_R', 'R_mag') - results.rename_column('FLUX_I', 'I_mag') - results.rename_column('FLUX_J', 'J_mag') - results.rename_column('FLUX_H', 'H_mag') - results.rename_column('FLUX_K', 'K_mag') - results.rename_column('FLUX_u', 'u_mag') - results.rename_column('FLUX_g', 'g_mag') - results.rename_column('FLUX_r', 'r_mag') - results.rename_column('FLUX_i', 'i_mag') - results.rename_column('FLUX_z', 'z_mag') - results.rename_column('OTYPE', 'otype') - results.remove_column('SCRIPT_NUMBER_ID') - results.sort(['V_mag', 'B_mag', 'H_mag']) - - # Filter out object types which shouldn'r really be in there anyway: - indx = (results['otype'] == 'Galaxy') | (results['otype'] == 'LINER') | (results['otype'] == 'SN') - results = results[~indx] - - if len(results) == 0: - return None, None - - # Build sky coordinates object: - simbad = SkyCoord( - ra=results['ra'], - dec=results['dec'], - pm_ra_cosdec=results['pmra'], - pm_dec=results['pmdec'], - frame='icrs', - obstime='J2000') - - return results, simbad - -#-------------------------------------------------------------------------------------------------- -def query_skymapper(coo_centre, radius=24*u.arcmin): - """ - Queries SkyMapper catalog using cone-search around the position. - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float, optional): - - Returns: - tuple: - - :class:`astropy.table.Table`: Astropy Table with SkyMapper information. - - :class:`astropy.coordinates.SkyCoord`: - - .. 
codeauthor:: Rasmus Handberg - """ - - if isinstance(radius, (float, int)): - radius *= u.deg - - # Query the SkyMapper cone-search API: - params = { - 'RA': coo_centre.icrs.ra.deg, - 'DEC': coo_centre.icrs.dec.deg, - 'SR': Angle(radius).deg, - 'CATALOG': 'dr2.master', - 'VERB': 1, - 'RESPONSEFORMAT': 'VOTABLE' - } - res = requests.get('http://skymapper.anu.edu.au/sm-cone/public/query', params=params) - res.raise_for_status() - - # For some reason the VOTable parser needs a file-like object: - fid = BytesIO(bytes(res.text, 'utf8')) - results = Table.read(fid, format='votable') - - if len(results) == 0: - return None, None - - # Clean the results: - # http://skymapper.anu.edu.au/data-release/dr2/#Access - indx = (results['flags'] == 0) & (results['nimaflags'] == 0) & (results['ngood'] > 1) - results = results[indx] - if len(results) == 0: - return None, None - - # Create SkyCoord object containing SkyMapper objects with their observation time: - skymapper = SkyCoord( - ra=results['raj2000'], - dec=results['dej2000'], - obstime=Time(results['mean_epoch'], format='mjd', scale='utc'), - frame='icrs') - - return results, skymapper - -#-------------------------------------------------------------------------------------------------- -def query_all(coo_centre, radius=24*u.arcmin, dist_cutoff=2*u.arcsec): - """ - Query all catalogs (REFCAT2, APASS, SDSS and SkyMapper) and return merged catalog. - - Merging of catalogs are done using sky coordinates: - https://docs.astropy.org/en/stable/coordinates/matchsep.html#matching-catalogs - - Parameters: - coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. - radius (float): Search radius. Default 24 arcmin. - dist_cutoff (float): Maximal distance between object is catalog matching. Default 2 arcsec. - - Returns: - :class:`astropy.table.Table`: Table with catalog stars. - - TODO: - - Use the overlapping magnitudes to make better matching. - - .. codeauthor:: Rasmus Handberg - .. 
codeauthor:: Emir Karamehmetoglu - """ - - # Query the REFCAT2 catalog using CasJobs around the target position: - results = query_casjobs_refcat2(coo_centre, radius=radius) - AT_results = Table(results) - refcat = SkyCoord(ra=AT_results['ra'], dec=AT_results['decl'], unit=u.deg, frame='icrs') - - # REFCAT results table does not have uBV - N = len(AT_results) - d = np.full(N, np.NaN) - AT_results.add_column(MaskedColumn(name='B_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) - AT_results.add_column(MaskedColumn(name='V_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) - AT_results.add_column(MaskedColumn(name='u_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) - - # Query APASS around the target position: - results_apass = query_apass(coo_centre, radius=radius) - if results_apass: - AT_apass = Table(results_apass) - apass = SkyCoord(ra=AT_apass['ra'], dec=AT_apass['decl'], unit=u.deg, frame='icrs') - - # Match the two catalogs using coordinates: - idx, d2d, _ = apass.match_to_catalog_sky(refcat) - sep_constraint = (d2d <= dist_cutoff) # Reject any match further away than the cutoff - idx_apass = np.arange(len(idx), dtype='int') # since idx maps apass to refcat - - # Update results table with APASS bands of interest - AT_results['B_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['B_mag'] - AT_results['V_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['V_mag'] - AT_results['u_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['u_mag'] - - # Create SDSS around the target position: - AT_sdss, sdss = query_sdss(coo_centre, radius=radius) - if AT_sdss: - # Match to dist_cutoff sky distance (angular) apart - idx, d2d, _ = sdss.match_to_catalog_sky(refcat) - sep_constraint = (d2d <= dist_cutoff) - idx_sdss = np.arange(len(idx), dtype='int') # since idx maps sdss to refcat - - # Overwrite APASS u-band with SDSS u-band: - AT_results['u_mag'][idx[sep_constraint]] = AT_sdss[idx_sdss[sep_constraint]]['psfMag_u'] - - # Query SkyMapper around the target position, only if there are missing u-band magnitudes: - if anynan(AT_results['u_mag']): - results_skymapper, skymapper = query_skymapper(coo_centre, radius=radius) - if results_skymapper: - idx, d2d, _ = skymapper.match_to_catalog_sky(refcat) - sep_constraint = (d2d <= dist_cutoff) - idx_skymapper = np.arange(len(idx), dtype='int') # since idx maps skymapper to refcat - - newval = results_skymapper[idx_skymapper[sep_constraint]]['u_psf'] - oldval = AT_results['u_mag'][idx[sep_constraint]] - indx = ~np.isfinite(oldval) - if np.any(indx): - AT_results['u_mag'][idx[sep_constraint]][indx] = newval[indx] - - return AT_results - -#-------------------------------------------------------------------------------------------------- + logger = logging.getLogger(__name__) + + # Just put in a stop criterion to avoid infinite recursion: + if radius < 0.04 * u.deg: + raise Exception("Too many subdivides") + + # Search central cone: + try: + results = _query_casjobs_refcat2(coo_centre, radius=0.5 * radius) + except CasjobsMemoryError: + logger.debug("CasJobs failed with memory error. Trying to use smaller radii.") + results = _query_casjobs_refcat2_divide_and_conquer(coo_centre, radius=0.5 * radius) + + # Search six cones around central cone: + for n in range(6): + # FIXME: The 0.8 here is kind of a guess. 
There should be an analytic solution
+            # (for covering a disc of radius r with seven discs of radius r/2, the six
+            # outer centres would sit at (sqrt(3)/2)*r ~ 0.866*r; see also
+            # notes/disk_covering_problem.py).
+            new = SkyCoord(ra=coo_centre.ra.deg + 0.8 * Angle(radius).deg * np.cos(n * 60 * np.pi / 180),
+                           dec=coo_centre.dec.deg + 0.8 * Angle(radius).deg * np.sin(n * 60 * np.pi / 180), unit='deg',
+                           frame='icrs')
+
+        try:
+            results += _query_casjobs_refcat2(new, radius=0.5 * radius)
+        except CasjobsMemoryError:
+            logger.debug("CasJobs failed with memory error. Trying to use smaller radii.")
+            results += _query_casjobs_refcat2_divide_and_conquer(new, radius=0.5 * radius)
+
+    return results
+
+
+# --------------------------------------------------------------------------------------------------
+def _query_casjobs_refcat2(coo_centre, radius=24 * u.arcmin):
+    """
+    Uses the CasJobs program to do a cone-search around the position.
+
+    Parameters:
+        coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone.
+        radius (Angle, optional): Search radius. Default is 24 arcmin.
+
+    Returns:
+        list: List of dicts with the REFCAT2 information.
+
+    .. codeauthor:: Rasmus Handberg
+    """
+
+    logger = logging.getLogger(__name__)
+    if isinstance(radius, (float, int)):
+        radius *= u.deg
+
+    sql = "SELECT r.* FROM fGetNearbyObjEq({ra:f}, {dec:f}, {radius:f}) AS n INNER JOIN HLSP_ATLAS_REFCAT2.refcat2 AS r ON n.objid=r.objid ORDER BY n.distance;".format(
+        ra=coo_centre.ra.deg, dec=coo_centre.dec.deg, radius=Angle(radius).deg)
+    logger.debug(sql)
+
+    # Make sure that CasJobs has been configured:
+    configure_casjobs()
+
+    # The command to run the casjobs script:
+    # BEWARE: This may change in the future without warning - it has before!
+    cmd = 'java -jar casjobs.jar execute "{0:s}"'.format(sql)
+
+    # Execute the command:
+    cmd = shlex.split(cmd)
+    directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'casjobs')
+    proc = subprocess.Popen(cmd, cwd=directory, stdout=subprocess.PIPE, universal_newlines=True)
+    stdout, stderr = proc.communicate()
+    output = stdout.split("\n")
+
+    # Build a list of result rows from the output of the CasJobs script:
+    error_thrown = False
+    results = []
+    for line in output:
+        line = line.strip()
+        if line == '':
+            continue
+        if 'ERROR' in line:
+            error_thrown = True
+            break
+
+        row = line.split(',')
+        if len(row) == 45 and row[0] != '[objid]:Integer':
+            results.append(
+                {'starid': int(row[0]), 'ra': floatval(row[1]), 'decl': floatval(row[2]), 'pm_ra': floatval(row[5]),
+                 'pm_dec': floatval(row[7]), 'gaia_mag': floatval(row[9]), 'gaia_bp_mag': floatval(row[11]),
+                 'gaia_rp_mag': floatval(row[13]), 'gaia_variability': intval(row[17]), 'g_mag': floatval(row[22]),
+                 'r_mag': floatval(row[26]), 'i_mag': floatval(row[30]), 'z_mag': floatval(row[34]),
+                 'J_mag': floatval(row[39]), 'H_mag': floatval(row[41]), 'K_mag': floatval(row[43])})
+
+    if error_thrown:
+        error_msg = ''
+        for line in output:
+            if len(line.strip()) > 0:
+                error_msg += line.strip() + "\n"
+
+        logger.debug("Error Msg: %s", error_msg)
+        if 'query results exceed memory limit' in error_msg.lower():
+            raise CasjobsMemoryError("Query results exceed memory limit")
+        else:
+            raise CasjobsError("ERROR detected in CasJobs: " + error_msg)
+
+    if not results:
+        raise CasjobsError("Could not find anything on CasJobs")
+
+    logger.debug("Found %d results", len(results))
+    return results
+
+
+# --------------------------------------------------------------------------------------------------
+def query_apass(coo_centre, radius=24 * u.arcmin):
+    """
+    Queries the APASS catalog using a cone-search around the position. 
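+
+    Example (a minimal sketch; the position and radius below are arbitrary
+    illustration values, not anything assumed by this module):
+
+        >>> import astropy.units as u
+        >>> from astropy.coordinates import SkyCoord
+        >>> coo_centre = SkyCoord(ra=180.0, dec=-30.0, unit='deg', frame='icrs')
+        >>> rows = query_apass(coo_centre, radius=10 * u.arcmin)  # doctest: +SKIP
+        >>> rows[0]['V_mag']  # each entry is a dict of positions and magnitudes  # doctest: +SKIP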
+
+    Parameters:
+        coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone.
+        radius (float, optional): Search radius. Default is 24 arcmin.
+
+    Returns:
+        list: List of dicts with the APASS information.
+
+    .. codeauthor:: Rasmus Handberg
+    """
+
+    # https://vizier.u-strasbg.fr/viz-bin/VizieR-3?-source=II/336
+
+    if isinstance(radius, (float, int)):
+        radius *= u.deg
+
+    data = {'ra': coo_centre.icrs.ra.deg, 'dec': coo_centre.icrs.dec.deg, 'radius': Angle(radius).deg, 'outtype': '1'}
+
+    res = requests.post('https://www.aavso.org/cgi-bin/apass_dr10_download.pl', data=data)
+    res.raise_for_status()
+
+    results = []
+
+    lines = res.text.split("\n")
+    # header = lines[0]
+
+    for line in lines[1:]:
+        if line.strip() == '':
+            continue
+        row = line.strip().split(',')
+
+        results.append(
+            {'ra': floatval(row[0]), 'decl': floatval(row[2]), 'V_mag': floatval(row[4]), 'B_mag': floatval(row[7]),
+             'u_mag': floatval(row[10]), 'g_mag': floatval(row[13]), 'r_mag': floatval(row[16]),
+             'i_mag': floatval(row[19]), 'z_mag': floatval(row[22]), 'Y_mag': floatval(row[25])})
+
+    return results
+
+
+# --------------------------------------------------------------------------------------------------
+def query_sdss(coo_centre, radius=24 * u.arcmin, dr=16, clean=True):
+    """
+    Queries the SDSS catalog using a cone-search around the position.
+
+    Parameters:
+        coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone.
+        radius (float, optional): Search radius. Default is 24 arcmin.
+        dr (int, optional): SDSS Data Release to query. Default=16.
+        clean (bool, optional): Clean the results, keeping only stars without quality problems. Default=True.
+
+    Returns:
+        tuple:
+            - :class:`astropy.table.Table`: Table with SDSS information.
+            - :class:`astropy.coordinates.SkyCoord`: Sky coordinates for SDSS objects.
+
+    .. codeauthor:: Emir Karamehmetoglu
+    .. codeauthor:: Rasmus Handberg
+    """
+
+    if isinstance(radius, (float, int)):
+        radius *= u.deg
+
+    AT_sdss = SDSS.query_region(coo_centre, photoobj_fields=['type', 'clean', 'ra', 'dec', 'psfMag_u'], data_release=dr,
+                                timeout=600, radius=radius)
+
+    if AT_sdss is None:
+        return None, None
+
+    if clean:
+        # Clean SDSS following https://www.sdss.org/dr12/algorithms/photo_flags_recommend/
+        # 6 == star, clean means remove interp, edge, suspicious defects, deblending problems, duplicates.
+        AT_sdss = AT_sdss[(AT_sdss['type'] == 6) & (AT_sdss['clean'] == 1)]

+    # Remove these columns since they are no longer needed:
+    AT_sdss.remove_columns(['type', 'clean'])
+
+    if len(AT_sdss) == 0:
+        return None, None
+
+    # Create SkyCoord object with the coordinates:
+    sdss = SkyCoord(ra=AT_sdss['ra'], dec=AT_sdss['dec'], unit=u.deg, frame='icrs')
+
+    return AT_sdss, sdss
+
+
+# --------------------------------------------------------------------------------------------------
+def query_simbad(coo_centre, radius=24 * u.arcmin):
+    """
+    Query SIMBAD using a cone-search around the position, via astroquery.
+
+    Parameters:
+        coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone.
+        radius (float, optional): Search radius. Default is 24 arcmin.
+
+    Returns:
+        tuple:
+            - :class:`astropy.table.Table`: Astropy Table with SIMBAD information.
+            - :class:`astropy.coordinates.SkyCoord`: Sky coordinates (including proper motions) for the objects.
+
+    .. 
codeauthor:: Rasmus Handberg
+    """
+
+    s = Simbad()
+    s.ROW_LIMIT = 0
+    s.remove_votable_fields('coordinates')
+    s.add_votable_fields('ra(d;A;ICRS;J2000)', 'dec(d;D;ICRS;2000)', 'pmra', 'pmdec')
+    s.add_votable_fields('otype')
+    s.add_votable_fields('flux(B)', 'flux(V)', 'flux(R)', 'flux(I)', 'flux(J)', 'flux(H)', 'flux(K)')
+    s.add_votable_fields('flux(u)', 'flux(g)', 'flux(r)', 'flux(i)', 'flux(z)')
+
+    rad = Angle(radius).arcmin
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', category=UserWarning)
+        results = s.query_criteria(
+            f'region(circle, icrs, {coo_centre.icrs.ra.deg:.10f} {coo_centre.icrs.dec.deg:+.10f}, {rad}m)',
+            otypes='Star')
+
+    if not results:
+        return None, None
+
+    # Rename columns:
+    results.rename_column('MAIN_ID', 'main_id')
+    results.rename_column('RA_d_A_ICRS_J2000', 'ra')
+    results.rename_column('DEC_d_D_ICRS_2000', 'dec')
+    results.rename_column('PMRA', 'pmra')
+    results.rename_column('PMDEC', 'pmdec')
+    results.rename_column('FLUX_B', 'B_mag')
+    results.rename_column('FLUX_V', 'V_mag')
+    results.rename_column('FLUX_R', 'R_mag')
+    results.rename_column('FLUX_I', 'I_mag')
+    results.rename_column('FLUX_J', 'J_mag')
+    results.rename_column('FLUX_H', 'H_mag')
+    results.rename_column('FLUX_K', 'K_mag')
+    results.rename_column('FLUX_u', 'u_mag')
+    results.rename_column('FLUX_g', 'g_mag')
+    results.rename_column('FLUX_r', 'r_mag')
+    results.rename_column('FLUX_i', 'i_mag')
+    results.rename_column('FLUX_z', 'z_mag')
+    results.rename_column('OTYPE', 'otype')
+    results.remove_column('SCRIPT_NUMBER_ID')
+    results.sort(['V_mag', 'B_mag', 'H_mag'])
+
+    # Filter out object types which shouldn't really be in there anyway:
+    indx = (results['otype'] == 'Galaxy') | (results['otype'] == 'LINER') | (results['otype'] == 'SN')
+    results = results[~indx]
+
+    if len(results) == 0:
+        return None, None
+
+    # Build sky coordinates object:
+    simbad = SkyCoord(ra=results['ra'], dec=results['dec'], pm_ra_cosdec=results['pmra'], pm_dec=results['pmdec'],
+                      frame='icrs', obstime='J2000')
+
+    return results, simbad
+
+
+# --------------------------------------------------------------------------------------------------
+def query_skymapper(coo_centre, radius=24 * u.arcmin):
+    """
+    Queries the SkyMapper catalog using a cone-search around the position.
+
+    Parameters:
+        coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone.
+        radius (float, optional): Search radius. Default is 24 arcmin.
+
+    Returns:
+        tuple:
+            - :class:`astropy.table.Table`: Astropy Table with SkyMapper information.
+            - :class:`astropy.coordinates.SkyCoord`: Sky coordinates (with observation epochs) for the objects.
+
+    .. 
codeauthor:: Rasmus Handberg
+    """
+
+    if isinstance(radius, (float, int)):
+        radius *= u.deg
+
+    # Query the SkyMapper cone-search API:
+    params = {'RA': coo_centre.icrs.ra.deg, 'DEC': coo_centre.icrs.dec.deg, 'SR': Angle(radius).deg,
+              'CATALOG': 'dr2.master', 'VERB': 1, 'RESPONSEFORMAT': 'VOTABLE'}
+    res = requests.get('http://skymapper.anu.edu.au/sm-cone/public/query', params=params)
+    res.raise_for_status()
+
+    # For some reason the VOTable parser needs a file-like object:
+    fid = BytesIO(bytes(res.text, 'utf8'))
+    results = Table.read(fid, format='votable')
+
+    if len(results) == 0:
+        return None, None
+
+    # Clean the results:
+    # http://skymapper.anu.edu.au/data-release/dr2/#Access
+    indx = (results['flags'] == 0) & (results['nimaflags'] == 0) & (results['ngood'] > 1)
+    results = results[indx]
+    if len(results) == 0:
+        return None, None
+
+    # Create SkyCoord object containing SkyMapper objects with their observation time:
+    skymapper = SkyCoord(ra=results['raj2000'], dec=results['dej2000'],
+                         obstime=Time(results['mean_epoch'], format='mjd', scale='utc'), frame='icrs')
+
+    return results, skymapper
+
+
+# --------------------------------------------------------------------------------------------------
+def query_all(coo_centre, radius=24 * u.arcmin, dist_cutoff=2 * u.arcsec):
+    """
+    Query all catalogs (REFCAT2, APASS, SDSS and SkyMapper) and return a merged catalog.
+
+    Merging of the catalogs is done using sky coordinates:
+    https://docs.astropy.org/en/stable/coordinates/matchsep.html#matching-catalogs
+
+    Parameters:
+        coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone.
+        radius (float): Search radius. Default 24 arcmin.
+        dist_cutoff (float): Maximal distance between objects in catalog matching. Default 2 arcsec.
+
+    Returns:
+        :class:`astropy.table.Table`: Table with catalog stars.
+
+    TODO:
+        - Use the overlapping magnitudes to make better matching.
+
+    .. codeauthor:: Rasmus Handberg
+    .. 
codeauthor:: Emir Karamehmetoglu + """ + + # Query the REFCAT2 catalog using CasJobs around the target position: + results = query_casjobs_refcat2(coo_centre, radius=radius) + AT_results = Table(results) + refcat = SkyCoord(ra=AT_results['ra'], dec=AT_results['decl'], unit=u.deg, frame='icrs') + + # REFCAT results table does not have uBV + N = len(AT_results) + d = np.full(N, np.NaN) + AT_results.add_column(MaskedColumn(name='B_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) + AT_results.add_column(MaskedColumn(name='V_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) + AT_results.add_column(MaskedColumn(name='u_mag', unit='mag', dtype='float32', fill_value=np.NaN, data=d)) + + # Query APASS around the target position: + results_apass = query_apass(coo_centre, radius=radius) + if results_apass: + AT_apass = Table(results_apass) + apass = SkyCoord(ra=AT_apass['ra'], dec=AT_apass['decl'], unit=u.deg, frame='icrs') + + # Match the two catalogs using coordinates: + idx, d2d, _ = apass.match_to_catalog_sky(refcat) + sep_constraint = (d2d <= dist_cutoff) # Reject any match further away than the cutoff + idx_apass = np.arange(len(idx), dtype='int') # since idx maps apass to refcat + + # Update results table with APASS bands of interest + AT_results['B_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['B_mag'] + AT_results['V_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['V_mag'] + AT_results['u_mag'][idx[sep_constraint]] = AT_apass[idx_apass[sep_constraint]]['u_mag'] + + # Create SDSS around the target position: + AT_sdss, sdss = query_sdss(coo_centre, radius=radius) + if AT_sdss: + # Match to dist_cutoff sky distance (angular) apart + idx, d2d, _ = sdss.match_to_catalog_sky(refcat) + sep_constraint = (d2d <= dist_cutoff) + idx_sdss = np.arange(len(idx), dtype='int') # since idx maps sdss to refcat + + # Overwrite APASS u-band with SDSS u-band: + AT_results['u_mag'][idx[sep_constraint]] = AT_sdss[idx_sdss[sep_constraint]]['psfMag_u'] + + # Query SkyMapper around the target position, only if there are missing u-band magnitudes: + if anynan(AT_results['u_mag']): + results_skymapper, skymapper = query_skymapper(coo_centre, radius=radius) + if results_skymapper: + idx, d2d, _ = skymapper.match_to_catalog_sky(refcat) + sep_constraint = (d2d <= dist_cutoff) + idx_skymapper = np.arange(len(idx), dtype='int') # since idx maps skymapper to refcat + + newval = results_skymapper[idx_skymapper[sep_constraint]]['u_psf'] + oldval = AT_results['u_mag'][idx[sep_constraint]] + indx = ~np.isfinite(oldval) + if np.any(indx): + AT_results['u_mag'][idx[sep_constraint]][indx] = newval[indx] + + return AT_results + + +# -------------------------------------------------------------------------------------------------- def convert_table_to_dict(tab): - """ - Utility function for converting Astropy Table to list of dicts that the database - likes as input. - - Parameters: - tab (:class:`astropy.table.Table`): Astropy table coming from query_all. - - Returns: - list: List of dicts where the column names are the keys. Datatypes are changed - to things that the database understands (e.g. NaN -> None). - - .. 
codeauthor:: Rasmus Handberg - """ - results = [dict(zip(tab.colnames, row)) for row in tab.filled()] - for row in results: - for key, val in row.items(): - if isinstance(val, (np.int64, np.int32)): - row[key] = int(val) - elif isinstance(val, (float, np.float32, np.float64)): - if np.isfinite(val): - row[key] = float(val) - else: - row[key] = None - - return results - -#-------------------------------------------------------------------------------------------------- -def download_catalog(target=None, radius=24*u.arcmin, radius_ztf=3*u.arcsec, - dist_cutoff=2*u.arcsec, update_existing=False): - """ - Download reference star catalogs and save to Flows database. - - Parameters: - target (str or int): Target identifier to download catalog for. - radius (Angle, optional): Radius around target to download catalogs. - radius_ztf (Angle, optional): Radius around target to search for ZTF identifier. - dist_cutoff (Angle, optional): Distance cutoff used for matching catalog positions. - update_existing (bool, optional): Update existing catalog entries or skip them. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - with AADC_DB() as db: - - # Get the information about the target from the database: - if target is not None and isinstance(target, (int, float)): - db.cursor.execute("SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE targetid=%s;", [int(target)]) - elif target is not None: - db.cursor.execute("SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE target_name=%s;", [target]) - else: - db.cursor.execute("SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE catalog_downloaded=FALSE;") - - for row in db.cursor.fetchall(): - # The unique identifier of the target: - targetid = int(row['targetid']) - target_name = row['target_name'] - dd = row['discovery_date'] - if dd is not None: - dd = Time(dd, format='datetime', scale='utc') - - # Coordinate of the target, which is the centre of the search cone: - coo_centre = SkyCoord(ra=row['ra'], dec=row['decl'], unit=u.deg, frame='icrs') - - # Download combined catalog from all sources: - tab = query_all(coo_centre, radius=radius, dist_cutoff=dist_cutoff) - - # Query for a ZTF identifier for this target: - ztf_id = query_ztf_id(coo_centre, radius=radius_ztf, discovery_date=dd) - - # Because the database is picky with datatypes, we need to change things - # before they are passed on to the database: - results = convert_table_to_dict(tab) - - # Insert the catalog into the local database: - if update_existing: - on_conflict = """ON CONSTRAINT refcat2_pkey DO UPDATE SET + """ + Utility function for converting Astropy Table to list of dicts that the database + likes as input. + + Parameters: + tab (:class:`astropy.table.Table`): Astropy table coming from query_all. + + Returns: + list: List of dicts where the column names are the keys. Datatypes are changed + to things that the database understands (e.g. NaN -> None). + + .. 
codeauthor:: Rasmus Handberg + """ + results = [dict(zip(tab.colnames, row)) for row in tab.filled()] + for row in results: + for key, val in row.items(): + if isinstance(val, (np.int64, np.int32)): + row[key] = int(val) + elif isinstance(val, (float, np.float32, np.float64)): + if np.isfinite(val): + row[key] = float(val) + else: + row[key] = None + + return results + + +# -------------------------------------------------------------------------------------------------- +def download_catalog(target=None, radius=24 * u.arcmin, radius_ztf=3 * u.arcsec, dist_cutoff=2 * u.arcsec, + update_existing=False): + """ + Download reference star catalogs and save to Flows database. + + Parameters: + target (str or int): Target identifier to download catalog for. + radius (Angle, optional): Radius around target to download catalogs. + radius_ztf (Angle, optional): Radius around target to search for ZTF identifier. + dist_cutoff (Angle, optional): Distance cutoff used for matching catalog positions. + update_existing (bool, optional): Update existing catalog entries or skip them. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + + with AADC_DB() as db: + + # Get the information about the target from the database: + if target is not None and isinstance(target, (int, float)): + db.cursor.execute( + "SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE targetid=%s;", + [int(target)]) + elif target is not None: + db.cursor.execute( + "SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE target_name=%s;", [target]) + else: + db.cursor.execute( + "SELECT targetid,target_name,ra,decl,discovery_date FROM flows.targets WHERE catalog_downloaded=FALSE;") + + for row in db.cursor.fetchall(): + # The unique identifier of the target: + targetid = int(row['targetid']) + target_name = row['target_name'] + dd = row['discovery_date'] + if dd is not None: + dd = Time(dd, format='datetime', scale='utc') + + # Coordinate of the target, which is the centre of the search cone: + coo_centre = SkyCoord(ra=row['ra'], dec=row['decl'], unit=u.deg, frame='icrs') + + # Download combined catalog from all sources: + tab = query_all(coo_centre, radius=radius, dist_cutoff=dist_cutoff) + + # Query for a ZTF identifier for this target: + ztf_id = query_ztf_id(coo_centre, radius=radius_ztf, discovery_date=dd) + + # Because the database is picky with datatypes, we need to change things + # before they are passed on to the database: + results = convert_table_to_dict(tab) + + # Insert the catalog into the local database: + if update_existing: + on_conflict = """ON CONSTRAINT refcat2_pkey DO UPDATE SET ra=EXCLUDED.ra, decl=EXCLUDED.decl, pm_ra=EXCLUDED.pm_ra, @@ -650,11 +621,11 @@ def download_catalog(target=None, radius=24*u.arcmin, radius_ztf=3*u.arcsec, "V_mag"=EXCLUDED."V_mag", "B_mag"=EXCLUDED."B_mag" WHERE refcat2.starid=EXCLUDED.starid""" - else: - on_conflict = 'DO NOTHING' + else: + on_conflict = 'DO NOTHING' - try: - db.cursor.executemany("""INSERT INTO flows.refcat2 ( + try: + db.cursor.executemany("""INSERT INTO flows.refcat2 ( starid, ra, decl, @@ -695,11 +666,12 @@ def download_catalog(target=None, radius=24*u.arcmin, radius_ztf=3*u.arcsec, %(V_mag)s, %(B_mag)s) ON CONFLICT """ + on_conflict + ";", results) - logger.info("%d catalog entries inserted for %s.", db.cursor.rowcount, target_name) - - # Mark the target that the catalog has been downloaded: - db.cursor.execute("UPDATE flows.targets SET catalog_downloaded=TRUE,ztf_id=%s WHERE targetid=%s;", 
(ztf_id, targetid)) - db.conn.commit() - except: # noqa: E722, pragma: no cover - db.conn.rollback() - raise + logger.info("%d catalog entries inserted for %s.", db.cursor.rowcount, target_name) + + # Mark the target that the catalog has been downloaded: + db.cursor.execute("UPDATE flows.targets SET catalog_downloaded=TRUE,ztf_id=%s WHERE targetid=%s;", + (ztf_id, targetid)) + db.conn.commit() + except: # noqa: E722, pragma: no cover + db.conn.rollback() + raise diff --git a/flows/config.py b/flows/config.py index eb9d8af..846b890 100644 --- a/flows/config.py +++ b/flows/config.py @@ -9,22 +9,23 @@ import configparser from functools import lru_cache -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- @lru_cache(maxsize=1) def load_config(): - """ - Load configuration file. + """ + Load configuration file. - Returns: - ``configparser.ConfigParser``: Configuration file. + Returns: + ``configparser.ConfigParser``: Configuration file. - .. codeauthor:: Rasmus Handberg - """ + .. codeauthor:: Rasmus Handberg + """ - config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'config.ini') - if not os.path.isfile(config_file): - raise FileNotFoundError("config.ini file not found") + config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'config.ini') + if not os.path.isfile(config_file): + raise FileNotFoundError("config.ini file not found") - config = configparser.ConfigParser() - config.read(config_file) - return config + config = configparser.ConfigParser() + config.read(config_file) + return config diff --git a/flows/coordinatematch/coordinatematch.py b/flows/coordinatematch/coordinatematch.py index 97a2acc..55f73eb 100644 --- a/flows/coordinatematch/coordinatematch.py +++ b/flows/coordinatematch/coordinatematch.py @@ -13,378 +13,331 @@ from networkx import Graph, connected_components from .wcs import WCS2 -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class CoordinateMatch(object): - def __init__(self, - xy, - rd, - xy_order=None, - rd_order=None, - xy_nmax=None, - rd_nmax=None, - n_triangle_packages=10, - triangle_package_size=10000, - maximum_angle_distance=0.001, - distance_factor=1): - - self.xy, self.rd = np.array(xy), np.array(rd) - - self._xy = xy - np.mean(xy, axis=0) - self._rd = rd - np.mean(rd, axis=0) - self._rd[:, 0] *= np.cos(np.deg2rad(self.rd[:, 1])) - - xy_n, rd_n = min(xy_nmax, len(xy)), min(rd_nmax, len(rd)) - - self.i_xy = xy_order[:xy_n] if xy_order is not None else np.arange( - xy_n) - self.i_rd = rd_order[:rd_n] if rd_order is not None else np.arange( - rd_n) - - self.n_triangle_packages = n_triangle_packages - self.triangle_package_size = triangle_package_size - - self.maximum_angle_distance = maximum_angle_distance - self.distance_factor = distance_factor - - self.triangle_package_generator = self._sorted_triangle_packages() - - self.i_xy_triangles = list() - self.i_rd_triangles = list() - self.parameters = None - self.neighbours = Graph() - - self.normalizations = type( - 'Normalizations', (object, ), - dict(ra=0.0001, dec=0.0001, scale=0.002, angle=0.002)) - - self.bounds = type( - 'Bounds', (object, ), - dict(xy=self.xy.mean(axis=0), - rd=None, - radius=None, - scale=None, - angle=None)) - - 
#---------------------------------------------------------------------------------------------- - def set_normalizations(self, ra=None, dec=None, scale=None, angle=None): - """ - Set normalization factors in the (ra, dec, scale, angle) space. - - Defaults are: - ra = 0.0001 degrees - dec = 0.0001 degrees - scale = 0.002 log(arcsec/pixel) - angle = 0.002 radians - """ - - if self.parameters is not None: - raise RuntimeError("can't change normalization after matching is started") - - # TODO: Dont use "assert" here - raise ValueError instead - assert ra is None or 0 < ra - assert dec is None or 0 < dec - assert scale is None or 0 < scale - assert angle is None or 0 < angle - - self.normalizations.ra = ra if ra is not None else self.normalizations.ra - self.normalizations.dec = dec if dec is not None else self.normalizations.dec - self.normalizations.scale = scale if scale is not None else self.normalizations.scale - self.normalizations.angle = angle if ra is not None else self.normalizations.angle - - #---------------------------------------------------------------------------------------------- - def set_bounds(self, x=None, y=None, ra=None, dec=None, radius=None, scale=None, angle=None): - """ - Set bounds for what are valid results. - - Set x, y, ra, dec and radius to specify that the x, y coordinates must be no - further that the radius [degrees] away from the ra, dec coordinates. - Set upper and lower bounds on the scale [log(arcsec/pixel)] and/or the angle - [radians] if those are known, possibly from previous observations with the - same system. - """ - - if self.parameters is not None: - raise RuntimeError("can't change bounds after matching is started") - - if [x, y, ra, dec, radius].count(None) == 5: - # TODO: Dont use "assert" here - raise ValueError instead - assert 0 <= ra < 360 - assert -180 <= dec <= 180 - assert 0 < radius - - self.bounds.xy = x, y - self.bounds.rd = ra, dec - self.bounds.radius = radius - - elif [x, y, ra, dec, radius].count(None) > 0: - raise ValueError('x, y, ra, dec and radius must all be specified') - - # TODO: Dont use "assert" here - raise ValueError instead - assert scale is None or 0 < scale[0] < scale[1] - assert angle is None or -np.pi <= angle[0] < angle[1] <= np.pi - - self.bounds.scale = scale if scale is not None else self.bounds.scale - self.bounds.angle = angle if angle is not None else self.bounds.angle + def __init__(self, xy, rd, xy_order=None, rd_order=None, xy_nmax=None, rd_nmax=None, n_triangle_packages=10, + triangle_package_size=10000, maximum_angle_distance=0.001, distance_factor=1): + + self.xy, self.rd = np.array(xy), np.array(rd) + + self._xy = xy - np.mean(xy, axis=0) + self._rd = rd - np.mean(rd, axis=0) + self._rd[:, 0] *= np.cos(np.deg2rad(self.rd[:, 1])) + + xy_n, rd_n = min(xy_nmax, len(xy)), min(rd_nmax, len(rd)) + + self.i_xy = xy_order[:xy_n] if xy_order is not None else np.arange(xy_n) + self.i_rd = rd_order[:rd_n] if rd_order is not None else np.arange(rd_n) + + self.n_triangle_packages = n_triangle_packages + self.triangle_package_size = triangle_package_size + + self.maximum_angle_distance = maximum_angle_distance + self.distance_factor = distance_factor + + self.triangle_package_generator = self._sorted_triangle_packages() + + self.i_xy_triangles = list() + self.i_rd_triangles = list() + self.parameters = None + self.neighbours = Graph() + + self.normalizations = type('Normalizations', (object,), dict(ra=0.0001, dec=0.0001, scale=0.002, angle=0.002)) + + self.bounds = type('Bounds', (object,), + 
dict(xy=self.xy.mean(axis=0), rd=None, radius=None, scale=None, angle=None))
+
+    # ----------------------------------------------------------------------------------------------
+    def set_normalizations(self, ra=None, dec=None, scale=None, angle=None):
+        """
+        Set normalization factors in the (ra, dec, scale, angle) space.
+
+        Defaults are:
+            ra = 0.0001 degrees
+            dec = 0.0001 degrees
+            scale = 0.002 log(arcsec/pixel)
+            angle = 0.002 radians
+        """
+
+        if self.parameters is not None:
+            raise RuntimeError("can't change normalization after matching is started")
+
+        # TODO: Don't use "assert" here - raise ValueError instead
+        assert ra is None or 0 < ra
+        assert dec is None or 0 < dec
+        assert scale is None or 0 < scale
+        assert angle is None or 0 < angle
+
+        self.normalizations.ra = ra if ra is not None else self.normalizations.ra
+        self.normalizations.dec = dec if dec is not None else self.normalizations.dec
+        self.normalizations.scale = scale if scale is not None else self.normalizations.scale
+        self.normalizations.angle = angle if angle is not None else self.normalizations.angle
+
+    # ----------------------------------------------------------------------------------------------
+    def set_bounds(self, x=None, y=None, ra=None, dec=None, radius=None, scale=None, angle=None):
+        """
+        Set bounds for what are valid results.
+
+        Set x, y, ra, dec and radius to specify that the x, y coordinates must be no
+        further than the radius [degrees] away from the ra, dec coordinates.
+        Set upper and lower bounds on the scale [log(arcsec/pixel)] and/or the angle
+        [radians] if those are known, possibly from previous observations with the
+        same system.
+        """
+
+        if self.parameters is not None:
+            raise RuntimeError("can't change bounds after matching is started")
+
+        # Either all five of (x, y, ra, dec, radius) are given, or none of them:
+        if [x, y, ra, dec, radius].count(None) == 0:
+            # TODO: Don't use "assert" here - raise ValueError instead
+            assert 0 <= ra < 360
+            assert -90 <= dec <= 90
+            assert 0 < radius
+
+            self.bounds.xy = x, y
+            self.bounds.rd = ra, dec
+            self.bounds.radius = radius
+
+        elif [x, y, ra, dec, radius].count(None) < 5:
+            raise ValueError('x, y, ra, dec and radius must all be specified')
+
+        # TODO: Don't use "assert" here - raise ValueError instead
+        assert scale is None or 0 < scale[0] < scale[1]
+        assert angle is None or -np.pi <= angle[0] < angle[1] <= np.pi
+
+        self.bounds.scale = scale if scale is not None else self.bounds.scale
+        self.bounds.angle = angle if angle is not None else self.bounds.angle
+
+    # ----------------------------------------------------------------------------------------------
+    def _sorted_triangles(self, pool):
+        # Yield all 3-combinations (a, b, c) of pool, ordered so that entries early
+        # in the pool are exhausted first:
+        for i, c in enumerate(pool):
+            for j, b in enumerate(pool[:i]):
+                for a in pool[:j]:
+                    yield a, b, c
+
+    # ----------------------------------------------------------------------------------------------
+    def _sorted_product_pairs(self, p, q):
+        i_p = np.argsort(np.arange(len(p)))
+        i_q = np.argsort(np.arange(len(q)))
+        for _i_p, _i_q in sorted(product(i_p, i_q), key=lambda idxs: sum(idxs)):
+            yield p[_i_p], q[_i_q]
+
+    # ----------------------------------------------------------------------------------------------
+    def _sorted_triangle_packages(self):
+
+        i_xy_triangle_generator = self._sorted_triangles(self.i_xy)
+        i_rd_triangle_generator = self._sorted_triangles(self.i_rd)
+
+        i_xy_triangle_slice_generator = (tuple(islice(i_xy_triangle_generator, self.triangle_package_size)) for _ in
+                                         count())
+        i_rd_triangle_slice_generator = (list(islice(i_rd_triangle_generator, self.triangle_package_size)) for _ in
+                                         count())
+
+        for 
n in count(step=self.n_triangle_packages): + + i_xy_triangle_slice = tuple(filter(None, islice(i_xy_triangle_slice_generator, self.n_triangle_packages))) + i_rd_triangle_slice = tuple(filter(None, islice(i_rd_triangle_slice_generator, self.n_triangle_packages))) + + if not len(i_xy_triangle_slice) and not len(i_rd_triangle_slice): + return + + i_xy_triangle_generator2 = self._sorted_triangles(self.i_xy) + i_rd_triangle_generator2 = self._sorted_triangles(self.i_rd) + + i_xy_triangle_cum = filter(None, + (tuple(islice(i_xy_triangle_generator2, self.triangle_package_size)) for _ in + range(n))) + i_rd_triangle_cum = filter(None, + (tuple(islice(i_rd_triangle_generator2, self.triangle_package_size)) for _ in + range(n))) + + for i_xy_triangles, i_rd_triangles in chain(filter(None, chain(*zip_longest( # alternating chain + product(i_xy_triangle_slice, i_rd_triangle_cum), product(i_xy_triangle_cum, i_rd_triangle_slice)))), + self._sorted_product_pairs(i_xy_triangle_slice, + i_rd_triangle_slice)): + yield np.array(i_xy_triangles), np.array(i_rd_triangles) + + # ---------------------------------------------------------------------------------------------- + def _get_triangle_angles(self, triangles): - #---------------------------------------------------------------------------------------------- - def _sorted_triangles(self, pool): - for i, c in enumerate(pool): - for i, b in enumerate(pool[:i]): - for a in pool[:i]: + sidelengths = np.sqrt(np.power(triangles[:, (1, 0, 0)] - triangles[:, (2, 2, 1)], 2).sum(axis=2)) - yield a, b, c + # law of cosines + angles = np.power(sidelengths[:, ((1, 2), (0, 2), (0, 1))], 2).sum(axis=2) + angles -= np.power(sidelengths[:, (0, 1, 2)], 2) + angles /= 2 * sidelengths[:, ((1, 2), (0, 2), (0, 1))].prod(axis=2) - #---------------------------------------------------------------------------------------------- - def _sorted_product_pairs(self, p, q): - i_p = np.argsort(np.arange(len(p))) - i_q = np.argsort(np.arange(len(q))) - for _i_p, _i_q in sorted(product(i_p, i_q), key=lambda idxs: sum(idxs)): - yield p[_i_p], q[_i_q] - - #---------------------------------------------------------------------------------------------- - def _sorted_triangle_packages(self): - - i_xy_triangle_generator = self._sorted_triangles(self.i_xy) - i_rd_triangle_generator = self._sorted_triangles(self.i_rd) - - i_xy_triangle_slice_generator = (tuple( - islice(i_xy_triangle_generator, self.triangle_package_size)) for _ in count()) - i_rd_triangle_slice_generator = (list( - islice(i_rd_triangle_generator, self.triangle_package_size)) for _ in count()) + return np.arccos(angles) - for n in count(step=self.n_triangle_packages): - - i_xy_triangle_slice = tuple( - filter( - None, - islice(i_xy_triangle_slice_generator, - self.n_triangle_packages))) - i_rd_triangle_slice = tuple( - filter( - None, - islice(i_rd_triangle_slice_generator, - self.n_triangle_packages))) + # ---------------------------------------------------------------------------------------------- + def _solve_for_matrices(self, xy_triangles, rd_triangles): - if not len(i_xy_triangle_slice) and not len(i_rd_triangle_slice): - return + n = len(xy_triangles) - i_xy_triangle_generator2 = self._sorted_triangles(self.i_xy) - i_rd_triangle_generator2 = self._sorted_triangles(self.i_rd) - - i_xy_triangle_cum = filter(None, (tuple( - islice(i_xy_triangle_generator2, self.triangle_package_size)) for _ in range(n))) - i_rd_triangle_cum = filter(None, (tuple( - islice(i_rd_triangle_generator2, self.triangle_package_size)) for _ in range(n))) + 
A = xy_triangles - np.mean(xy_triangles, axis=1).reshape(n, 1, 2) + b = rd_triangles - np.mean(rd_triangles, axis=1).reshape(n, 1, 2) - for i_xy_triangles, i_rd_triangles in chain(filter(None, - chain(*zip_longest( # alternating chain - product(i_xy_triangle_slice, i_rd_triangle_cum), - product(i_xy_triangle_cum, i_rd_triangle_slice)))), - self._sorted_product_pairs(i_xy_triangle_slice, i_rd_triangle_slice)): - - yield np.array(i_xy_triangles), np.array(i_rd_triangles) - - #---------------------------------------------------------------------------------------------- - def _get_triangle_angles(self, triangles): - - sidelengths = np.sqrt( - np.power(triangles[:, (1, 0, 0)] - triangles[:, (2, 2, 1)], - 2).sum(axis=2)) + matrices = [np.linalg.lstsq(Ai, bi, rcond=None)[0].T for Ai, bi in zip(A, b)] - # law of cosines - angles = np.power(sidelengths[:, ((1, 2), (0, 2), (0, 1))], - 2).sum(axis=2) - angles -= np.power(sidelengths[:, (0, 1, 2)], 2) - angles /= 2 * sidelengths[:, ((1, 2), (0, 2), (0, 1))].prod(axis=2) + return np.array(matrices) - return np.arccos(angles) + # ---------------------------------------------------------------------------------------------- + def _extract_parameters(self, xy_triangles, rd_triangles, matrices): - #---------------------------------------------------------------------------------------------- - def _solve_for_matrices(self, xy_triangles, rd_triangles): + parameters = [] + for xy_com, rd_com, matrix in zip(xy_triangles.mean(axis=1), rd_triangles.mean(axis=1), matrices): + # com -> center-of-mass - n = len(xy_triangles) + cos_dec = np.cos(np.deg2rad(rd_com[1])) + coordinates = (self.bounds.xy - xy_com).dot(matrix) + coordinates = coordinates / np.array([cos_dec, 1]) + rd_com - A = xy_triangles - np.mean(xy_triangles, axis=1).reshape(n, 1, 2) - b = rd_triangles - np.mean(rd_triangles, axis=1).reshape(n, 1, 2) + wcs = WCS2.from_matrix(*xy_com, *rd_com, matrix) - matrices = [ - np.linalg.lstsq(Ai, bi, rcond=None)[0].T for Ai, bi in zip(A, b) - ] + parameters.append((*coordinates, np.log(wcs.scale), np.deg2rad(wcs.angle))) - return np.array(matrices) + return parameters - #---------------------------------------------------------------------------------------------- - def _extract_parameters(self, xy_triangles, rd_triangles, matrices): + # ---------------------------------------------------------------------------------------------- + def _get_bounds_mask(self, parameters): - parameters = [] - for xy_com, rd_com, matrix in zip(xy_triangles.mean(axis=1), rd_triangles.mean(axis=1), matrices): - # com -> center-of-mass + i = np.ones(len(parameters), dtype=bool) + parameters = np.array(parameters) - cos_dec = np.cos(np.deg2rad(rd_com[1])) - coordinates = (self.bounds.xy - xy_com).dot(matrix) - coordinates = coordinates / np.array([cos_dec, 1]) + rd_com + if self.bounds.radius is not None: + i *= angular_separation(*np.deg2rad(self.bounds.rd), + *zip(*np.deg2rad(parameters[:, (0, 1)]))) <= np.deg2rad(self.bounds.radius) - wcs = WCS2.from_matrix(*xy_com, *rd_com, matrix) + if self.bounds.scale is not None: + i *= self.bounds.scale[0] <= parameters[:, 2] + i *= parameters[:, 2] <= self.bounds.scale[1] - parameters.append((*coordinates, np.log(wcs.scale), np.deg2rad(wcs.angle))) + if self.bounds.angle is not None: + i *= self.bounds.angle[0] <= parameters[:, 3] + i *= parameters[:, 3] <= self.bounds.angle[1] - return parameters + return i - #---------------------------------------------------------------------------------------------- - def _get_bounds_mask(self, 
parameters): + # ---------------------------------------------------------------------------------------------- + def __call__(self, minimum_matches=4, ratio_superiority=1, timeout=60): + """ + Start the alogrithm. - i = np.ones(len(parameters), dtype=bool) - parameters = np.array(parameters) + Can be run multiple times with different arguments to relax the + restrictions. - if self.bounds.radius is not None: - i *= angular_separation( - *np.deg2rad(self.bounds.rd), - *zip(*np.deg2rad(parameters[:, (0, 1)])) - ) <= np.deg2rad(self.bounds.radius) + Example + -------- + cm = CoordinateMatch(xy, rd) - if self.bounds.scale is not None: - i *= self.bounds.scale[0] <= parameters[:, 2] - i *= parameters[:, 2] <= self.bounds.scale[1] + lkwargs = [{ + minimum_matches = 20, + ratio_superiority = 5, + timeout = 10 + },{ + timeout = 60 + } - if self.bounds.angle is not None: - i *= self.bounds.angle[0] <= parameters[:, 3] - i *= parameters[:, 3] <= self.bounds.angle[1] + for i, kwargs in enumerate(lkwargs): + try: + i_xy, i_rd = cm(**kwargs) + except TimeoutError: + continue + except StopIteration: + print('Failed, no more stars.') + else: + print('Success with kwargs[%d].' % i) + else: + print('Failed, timeout.') + """ - return i + self.parameters = list() if self.parameters is None else self.parameters - #---------------------------------------------------------------------------------------------- - def __call__(self, minimum_matches=4, ratio_superiority=1, timeout=60): - """ - Start the alogrithm. + t0 = time.time() - Can be run multiple times with different arguments to relax the - restrictions. + while time.time() - t0 < timeout: - Example - -------- - cm = CoordinateMatch(xy, rd) + # get triangles and derive angles - lkwargs = [{ - minimum_matches = 20, - ratio_superiority = 5, - timeout = 10 - },{ - timeout = 60 - } + i_xy_triangles, i_rd_triangles = next(self.triangle_package_generator) - for i, kwargs in enumerate(lkwargs): - try: - i_xy, i_rd = cm(**kwargs) - except TimeoutError: - continue - except StopIteration: - print('Failed, no more stars.') - else: - print('Success with kwargs[%d].' 
% i) - else: - print('Failed, timeout.') - """ + xy_angles = self._get_triangle_angles(self._xy[i_xy_triangles]) + rd_angles = self._get_triangle_angles(self._rd[i_rd_triangles]) - self.parameters = list() if self.parameters is None else self.parameters + # sort triangle vertices based on angles - t0 = time.time() + i = np.argsort(xy_angles, axis=1) + i_xy_triangles = np.take_along_axis(i_xy_triangles, i, axis=1) + xy_angles = np.take_along_axis(xy_angles, i, axis=1) - while time.time() - t0 < timeout: + i = np.argsort(rd_angles, axis=1) + i_rd_triangles = np.take_along_axis(i_rd_triangles, i, axis=1) + rd_angles = np.take_along_axis(rd_angles, i, axis=1) - # get triangles and derive angles + # match triangles + matches = KDTree(xy_angles).query_ball_tree(KDTree(rd_angles), r=self.maximum_angle_distance) + matches = np.array([(_i_xy, _i_rd) for _i_xy, _li_rd in enumerate(matches) for _i_rd in _li_rd]) - i_xy_triangles, i_rd_triangles = next( - self.triangle_package_generator) + if not len(matches): + continue - xy_angles = self._get_triangle_angles(self._xy[i_xy_triangles]) - rd_angles = self._get_triangle_angles(self._rd[i_rd_triangles]) + i_xy_triangles = list(i_xy_triangles[matches[:, 0]]) + i_rd_triangles = list(i_rd_triangles[matches[:, 1]]) - # sort triangle vertices based on angles + # get parameters of wcs solutions + matrices = self._solve_for_matrices(self._xy[np.array(i_xy_triangles)], self._rd[np.array(i_rd_triangles)]) - i = np.argsort(xy_angles, axis=1) - i_xy_triangles = np.take_along_axis(i_xy_triangles, i, axis=1) - xy_angles = np.take_along_axis(xy_angles, i, axis=1) + parameters = self._extract_parameters(self.xy[np.array(i_xy_triangles)], self.rd[np.array(i_rd_triangles)], + matrices) - i = np.argsort(rd_angles, axis=1) - i_rd_triangles = np.take_along_axis(i_rd_triangles, i, axis=1) - rd_angles = np.take_along_axis(rd_angles, i, axis=1) + # apply bounds if any + if any([self.bounds.radius, self.bounds.scale, self.bounds.angle]): + mask = self._get_bounds_mask(parameters) - # match triangles - matches = KDTree(xy_angles).query_ball_tree( - KDTree(rd_angles), r=self.maximum_angle_distance) - matches = np.array([(_i_xy, _i_rd) for _i_xy, _li_rd in enumerate(matches) for _i_rd in _li_rd]) + i_xy_triangles = np.array(i_xy_triangles)[mask].tolist() + i_rd_triangles = np.array(i_rd_triangles)[mask].tolist() + parameters = np.array(parameters)[mask].tolist() - if not len(matches): - continue + # normalize parameters + normalization = [getattr(self.normalizations, v) for v in ('ra', 'dec', 'scale', 'angle')] + normalization[0] *= np.cos(np.deg2rad(self.rd[:, 1].mean(axis=0))) + parameters = list(parameters / np.array(normalization)) - i_xy_triangles = list(i_xy_triangles[matches[:, 0]]) - i_rd_triangles = list(i_rd_triangles[matches[:, 1]]) + # match parameters + neighbours = KDTree(parameters).query_ball_tree(KDTree(self.parameters + parameters), + r=self.distance_factor) + neighbours = np.array([(i, j) for i, lj in enumerate(neighbours, len(self.parameters)) for j in lj]) + neighbours = list(neighbours[(np.diff(neighbours, axis=1) < 0).flatten()]) - # get parameters of wcs solutions - matrices = self._solve_for_matrices( - self._xy[np.array(i_xy_triangles)], - self._rd[np.array(i_rd_triangles)]) + if not len(neighbours): + continue - parameters = self._extract_parameters( - self.xy[np.array(i_xy_triangles)], - self.rd[np.array(i_rd_triangles)], matrices) - - # apply bounds if any - if any([self.bounds.radius, self.bounds.scale, self.bounds.angle]): - - mask = 
self._get_bounds_mask(parameters) - - i_xy_triangles = np.array(i_xy_triangles)[mask].tolist() - i_rd_triangles = np.array(i_rd_triangles)[mask].tolist() - parameters = np.array(parameters)[mask].tolist() - - # normalize parameters - normalization = [ - getattr(self.normalizations, v) - for v in ('ra', 'dec', 'scale', 'angle') - ] - normalization[0] *= np.cos(np.deg2rad(self.rd[:, 1].mean(axis=0))) - parameters = list(parameters / np.array(normalization)) + self.i_xy_triangles += i_xy_triangles + self.i_rd_triangles += i_rd_triangles + self.parameters += parameters + self.neighbours.add_edges_from(neighbours) - # match parameters - neighbours = KDTree(parameters).query_ball_tree( - KDTree(self.parameters + parameters), r=self.distance_factor) - neighbours = np.array([ - (i, j) for i, lj in enumerate(neighbours, len(self.parameters)) - for j in lj - ]) - neighbours = list( - neighbours[(np.diff(neighbours, axis=1) < 0).flatten()]) - - if not len(neighbours): - continue - - self.i_xy_triangles += i_xy_triangles - self.i_rd_triangles += i_rd_triangles - self.parameters += parameters - self.neighbours.add_edges_from(neighbours) - - # get largest neighborhood - communities = list(connected_components(self.neighbours)) - c1 = np.array(list(max(communities, key=len))) - i = np.unique(np.array(self.i_xy_triangles)[c1].flatten(), - return_index=True)[1] + # get largest neighborhood + communities = list(connected_components(self.neighbours)) + c1 = np.array(list(max(communities, key=len))) + i = np.unique(np.array(self.i_xy_triangles)[c1].flatten(), return_index=True)[1] - if ratio_superiority > 1 and len(communities) > 1: - communities.remove(set(c1)) - c2 = np.array(list(max(communities, key=len))) - _i = np.unique(np.array(self.i_xy_triangles)[c2].flatten()) + if ratio_superiority > 1 and len(communities) > 1: + communities.remove(set(c1)) + c2 = np.array(list(max(communities, key=len))) + _i = np.unique(np.array(self.i_xy_triangles)[c2].flatten()) - if len(i) / len(_i) < ratio_superiority: - continue + if len(i) / len(_i) < ratio_superiority: + continue - if len(i) >= minimum_matches: - break + if len(i) >= minimum_matches: + break - else: - raise TimeoutError + else: + raise TimeoutError - i_xy = np.array(self.i_xy_triangles)[c1].flatten()[i] - i_rd = np.array(self.i_rd_triangles)[c1].flatten()[i] + i_xy = np.array(self.i_xy_triangles)[c1].flatten()[i] + i_rd = np.array(self.i_rd_triangles)[c1].flatten()[i] - return list(zip(i_xy, i_rd)) + return list(zip(i_xy, i_rd)) diff --git a/flows/coordinatematch/wcs.py b/flows/coordinatematch/wcs.py index 02333d8..f851d72 100644 --- a/flows/coordinatematch/wcs.py +++ b/flows/coordinatematch/wcs.py @@ -11,217 +11,210 @@ from scipy.optimize import minimize from scipy.spatial.transform import Rotation -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class WCS2(): - """ - Manipulate WCS solution. + """ + Manipulate WCS solution. 
+ + Initialize + ---------- + wcs = WCS2(x, y, ra, dec, scale, mirror, angle) + wcs = WCS2.from_matrix(x, y, ra, dec, matrix) + wcs = WCS2.from_points(list(zip(x, y)), list(zip(ra, dec))) + wcs = WCS2.from_astropy_wcs(astropy.wcs.WCS()) - Initialize - ---------- - wcs = WCS2(x, y, ra, dec, scale, mirror, angle) - wcs = WCS2.from_matrix(x, y, ra, dec, matrix) - wcs = WCS2.from_points(list(zip(x, y)), list(zip(ra, dec))) - wcs = WCS2.from_astropy_wcs(astropy.wcs.WCS()) + ra, dec and angle should be in degrees + scale should be in arcsec/pixel + matrix should be the PC or CD matrix - ra, dec and angle should be in degrees - scale should be in arcsec/pixel - matrix should be the PC or CD matrix + Examples + -------- + Adjust x, y offset: + wcs.x += delta_x + wcs.y += delta_y - Examples - -------- - Adjust x, y offset: - wcs.x += delta_x - wcs.y += delta_y + Get scale and angle: + print(wcs.scale, wcs.angle) - Get scale and angle: - print(wcs.scale, wcs.angle) + Change an astropy.wcs.WCS (wcs) angle + wcs = WCS2(wcs)(angle=new_angle).astropy_wcs - Change an astropy.wcs.WCS (wcs) angle - wcs = WCS2(wcs)(angle=new_angle).astropy_wcs + Adjust solution with points + wcs.adjust_with_points(list(zip(x, y)), list(zip(ra, dec))) + """ - Adjust solution with points - wcs.adjust_with_points(list(zip(x, y)), list(zip(ra, dec))) - """ + # ---------------------------------------------------------------------------------------------- + def __init__(self, x, y, ra, dec, scale, mirror, angle): + self.x, self.y = x, y + self.ra, self.dec = ra, dec + self.scale = scale + self.mirror = mirror + self.angle = angle - #---------------------------------------------------------------------------------------------- - def __init__(self, x, y, ra, dec, scale, mirror, angle): - self.x, self.y = x, y - self.ra, self.dec = ra, dec - self.scale = scale - self.mirror = mirror - self.angle = angle + # ---------------------------------------------------------------------------------------------- + @classmethod + def from_matrix(cls, x, y, ra, dec, matrix): + '''Initiate the class with a matrix.''' - #---------------------------------------------------------------------------------------------- - @classmethod - def from_matrix(cls, x, y, ra, dec, matrix): - '''Initiate the class with a matrix.''' + assert np.shape(matrix) == (2, 2), 'Matrix must be 2x2' - assert np.shape(matrix) == (2, 2), \ - 'Matrix must be 2x2' + scale, mirror, angle = cls._decompose_matrix(matrix) - scale, mirror, angle = cls._decompose_matrix(matrix) + return cls(x, y, ra, dec, scale, mirror, angle) - return cls(x, y, ra, dec, scale, mirror, angle) + # ---------------------------------------------------------------------------------------------- + @classmethod + def from_points(cls, xy, rd): + """Initiate the class with at least pixel + sky coordinates.""" - #---------------------------------------------------------------------------------------------- - @classmethod - def from_points(cls, xy, rd): - """Initiate the class with at least pixel + sky coordinates.""" + assert np.shape(xy) == np.shape(rd) == (len(xy), 2) and len( + xy) > 2, 'Arguments must be lists of at least 3 sets of coordinates' - assert np.shape(xy) == np.shape(rd) == (len(xy), 2) and len(xy) > 2, \ - 'Arguments must be lists of at least 3 sets of coordinates' + xy, rd = np.array(xy), np.array(rd) - xy, rd = np.array(xy), np.array(rd) + x, y, ra, dec, matrix = cls._solve_from_points(xy, rd) + scale, mirror, angle = cls._decompose_matrix(matrix) - x, y, ra, dec, matrix = 
cls._solve_from_points(xy, rd) - scale, mirror, angle = cls._decompose_matrix(matrix) + return cls(x, y, ra, dec, scale, mirror, angle) - return cls(x, y, ra, dec, scale, mirror, angle) + # ---------------------------------------------------------------------------------------------- + @classmethod + def from_astropy_wcs(cls, astropy_wcs): + """Initiate the class with an astropy.wcs.WCS object.""" - #---------------------------------------------------------------------------------------------- - @classmethod - def from_astropy_wcs(cls, astropy_wcs): - """Initiate the class with an astropy.wcs.WCS object.""" + if not isinstance(astropy_wcs, astropy.wcs.WCS): + raise ValueError('Must be astropy.wcs.WCS') - if not isinstance(astropy_wcs, astropy.wcs.WCS): - raise ValueError('Must be astropy.wcs.WCS') + (x, y), (ra, dec) = astropy_wcs.wcs.crpix, astropy_wcs.wcs.crval + scale, mirror, angle = cls._decompose_matrix(astropy_wcs.pixel_scale_matrix) - (x, y), (ra, dec) = astropy_wcs.wcs.crpix, astropy_wcs.wcs.crval - scale, mirror, angle = cls._decompose_matrix( - astropy_wcs.pixel_scale_matrix) + return cls(x, y, ra, dec, scale, mirror, angle) - return cls(x, y, ra, dec, scale, mirror, angle) + # ---------------------------------------------------------------------------------------------- + def adjust_with_points(self, xy, rd): + """ + Adjust the WCS with pixel + sky coordinates. - #---------------------------------------------------------------------------------------------- - def adjust_with_points(self, xy, rd): - """ - Adjust the WCS with pixel + sky coordinates. + If one set is given the change will be a simple offset. + If two sets are given the offset, angle and scale will be derived. + And if more sets are given a completely new solution will be found. + """ - If one set is given the change will be a simple offset. - If two sets are given the offset, angle and scale will be derived. - And if more sets are given a completely new solution will be found. 
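+        # Illustrative usage sketch (hedged; the coordinate values are invented):
+        #
+        #     >>> wcs = WCS2.from_points([(0, 0), (100, 0), (0, 100)],
+        #     ...                        [(150.00, 2.20), (150.01, 2.20), (150.00, 2.21)])
+        #     >>> wcs.adjust_with_points([(50, 50)], [(150.005, 2.205)])  # one pair: offset only
+        #     >>> wcs.adjust_with_points([(0, 0), (100, 100)],
+        #     ...                        [(150.000, 2.200), (150.010, 2.210)])  # two pairs: also scale and angle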
- """ + assert np.shape(xy) == np.shape(rd) == (len(xy), 2), 'Arguments must be lists of sets of coordinates' - assert np.shape(xy) == np.shape(rd) == (len(xy), 2), \ - 'Arguments must be lists of sets of coordinates' + xy, rd = np.array(xy), np.array(rd) - xy, rd = np.array(xy), np.array(rd) + self.x, self.y = xy.mean(axis=0) + self.ra, self.dec = rd.mean(axis=0) - self.x, self.y = xy.mean(axis=0) - self.ra, self.dec = rd.mean(axis=0) + A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0) + b[:, 0] *= np.cos(np.deg2rad(rd[:, 1])) - A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0) - b[:, 0] *= np.cos(np.deg2rad(rd[:, 1])) + if len(xy) == 2: - if len(xy) == 2: + M = np.diag([[-1, 1][self.mirror], 1]) - M = np.diag([[-1, 1][self.mirror], 1]) + def R(t): + return np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]]) - def R(t): - return np.array([[np.cos(t), -np.sin(t)], - [np.sin(t), np.cos(t)]]) + def chi2(x): + return np.power(A.dot(x[1] / 60 / 60 * R(x[0]).dot(M).T) - b, 2).sum() - def chi2(x): - return np.power( - A.dot(x[1] / 60 / 60 * R(x[0]).dot(M).T) - b, 2).sum() - self.angle, self.scale = minimize(chi2, [self.angle, self.scale]).x + self.angle, self.scale = minimize(chi2, [self.angle, self.scale]).x - elif len(xy) > 2: - matrix = np.linalg.lstsq(A, b, rcond=None)[0].T - self.scale, self.mirror, self.angle = self._decompose_matrix( - matrix) + elif len(xy) > 2: + matrix = np.linalg.lstsq(A, b, rcond=None)[0].T + self.scale, self.mirror, self.angle = self._decompose_matrix(matrix) - #---------------------------------------------------------------------------------------------- - @property - def matrix(self): + # ---------------------------------------------------------------------------------------------- + @property + def matrix(self): - scale = self.scale / 60 / 60 - mirror = np.diag([[-1, 1][self.mirror], 1]) - angle = np.deg2rad(self.angle) + scale = self.scale / 60 / 60 + mirror = np.diag([[-1, 1][self.mirror], 1]) + angle = np.deg2rad(self.angle) - matrix = np.array([[np.cos(angle), -np.sin(angle)], - [np.sin(angle), np.cos(angle)]]) + matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) - return scale * matrix @ mirror + return scale * matrix @ mirror - #---------------------------------------------------------------------------------------------- - @property - def astropy_wcs(self): - wcs = astropy.wcs.WCS() - wcs.wcs.crpix = self.x, self.y - wcs.wcs.crval = self.ra, self.dec - wcs.wcs.pc = self.matrix - return wcs + # ---------------------------------------------------------------------------------------------- + @property + def astropy_wcs(self): + wcs = astropy.wcs.WCS() + wcs.wcs.crpix = self.x, self.y + wcs.wcs.crval = self.ra, self.dec + wcs.wcs.pc = self.matrix + return wcs - #---------------------------------------------------------------------------------------------- - @staticmethod - def _solve_from_points(xy, rd): + # ---------------------------------------------------------------------------------------------- + @staticmethod + def _solve_from_points(xy, rd): - (x, y), (ra, dec) = xy.mean(axis=0), rd.mean(axis=0) + (x, y), (ra, dec) = xy.mean(axis=0), rd.mean(axis=0) - A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0) - b[:, 0] *= np.cos(np.deg2rad(rd[:, 1])) + A, b = xy - xy.mean(axis=0), rd - rd.mean(axis=0) + b[:, 0] *= np.cos(np.deg2rad(rd[:, 1])) - matrix = np.linalg.lstsq(A, b, rcond=None)[0].T + matrix = np.linalg.lstsq(A, b, rcond=None)[0].T - return x, y, ra, dec, matrix + return x, y, ra, dec, matrix - 
#---------------------------------------------------------------------------------------------- - @staticmethod - def _decompose_matrix(matrix): + # ---------------------------------------------------------------------------------------------- + @staticmethod + def _decompose_matrix(matrix): - scale = np.sqrt(np.power(matrix, 2).sum() / 2) * 60 * 60 + scale = np.sqrt(np.power(matrix, 2).sum() / 2) * 60 * 60 - if np.argmax(np.power(matrix[0], 2)): - mirror = True if np.sign(matrix[0, 1]) != np.sign( - matrix[1, 0]) else False - else: - mirror = True if np.sign(matrix[0, 0]) == np.sign( - matrix[1, 1]) else False + if np.argmax(np.power(matrix[0], 2)): + mirror = True if np.sign(matrix[0, 1]) != np.sign(matrix[1, 0]) else False + else: + mirror = True if np.sign(matrix[0, 0]) == np.sign(matrix[1, 1]) else False - matrix = matrix if mirror else matrix.dot(np.diag([-1, 1])) + matrix = matrix if mirror else matrix.dot(np.diag([-1, 1])) - matrix3d = np.eye(3) - matrix3d[:2, :2] = matrix / (scale / 60 / 60) - angle = Rotation.from_matrix(matrix3d).as_euler('xyz', degrees=True)[2] + matrix3d = np.eye(3) + matrix3d[:2, :2] = matrix / (scale / 60 / 60) + angle = Rotation.from_matrix(matrix3d).as_euler('xyz', degrees=True)[2] - return scale, mirror, angle + return scale, mirror, angle - #---------------------------------------------------------------------------------------------- - def __setattr__(self, name, value): + # ---------------------------------------------------------------------------------------------- + def __setattr__(self, name, value): - if name == 'ra' and (value < 0 or value >= 360): - raise ValueError("0 <= R.A. < 360") + if name == 'ra' and (value < 0 or value >= 360): + raise ValueError("0 <= R.A. < 360") - elif name == 'dec' and (value < -180 or value > 180): - raise ValueError("-180 <= Dec. <= 180") + elif name == 'dec' and (value < -180 or value > 180): + raise ValueError("-180 <= Dec. 
<= 180") - elif name == 'scale' and value <= 0: - raise ValueError("Scale > 0") + elif name == 'scale' and value <= 0: + raise ValueError("Scale > 0") - elif name == 'mirror' and not isinstance(value, bool): - raise ValueError('mirror must be boolean') + elif name == 'mirror' and not isinstance(value, bool): + raise ValueError('mirror must be boolean') - elif name == 'angle' and (value <= -180 or value > 180): - raise ValueError("-180 < Angle <= 180") + elif name == 'angle' and (value <= -180 or value > 180): + raise ValueError("-180 < Angle <= 180") - super().__setattr__(name, value) + super().__setattr__(name, value) - #---------------------------------------------------------------------------------------------- - def __call__(self, **kwargs): - '''Make a copy with, or a copy with changes.''' + # ---------------------------------------------------------------------------------------------- + def __call__(self, **kwargs): + '''Make a copy with, or a copy with changes.''' - keys = ('x', 'y', 'ra', 'dec', 'scale', 'mirror', 'angle') + keys = ('x', 'y', 'ra', 'dec', 'scale', 'mirror', 'angle') - if not all(k in keys for k in kwargs): - raise ValueError('unknown argument(s)') + if not all(k in keys for k in kwargs): + raise ValueError('unknown argument(s)') - obj = deepcopy(self) - for k, v in kwargs.items(): - obj.__setattr__(k, v) - return obj + obj = deepcopy(self) + for k, v in kwargs.items(): + obj.__setattr__(k, v) + return obj - #---------------------------------------------------------------------------------------------- - def __repr__(self): - ra, dec = self.astropy_wcs.wcs_pix2world([(0, 0)], 0)[0] - return f'WCS2(0, 0, {ra:.4f}, {dec:.4f}, {self.scale:.2f}, {self.mirror}, {self.angle:.2f})' + # ---------------------------------------------------------------------------------------------- + def __repr__(self): + ra, dec = self.astropy_wcs.wcs_pix2world([(0, 0)], 0)[0] + return f'WCS2(0, 0, {ra:.4f}, {dec:.4f}, {self.scale:.2f}, {self.mirror}, {self.angle:.2f})' diff --git a/flows/epsfbuilder/epsfbuilder.py b/flows/epsfbuilder/epsfbuilder.py index 32c48d2..d8f8ebd 100644 --- a/flows/epsfbuilder/epsfbuilder.py +++ b/flows/epsfbuilder/epsfbuilder.py @@ -10,51 +10,44 @@ from scipy.interpolate import griddata import photutils.psf -class FlowsEPSFBuilder(photutils.psf.EPSFBuilder): - def _create_initial_epsf(self, stars): - - epsf = super()._create_initial_epsf(stars) - epsf.origin = None - X, Y = np.meshgrid(*map(np.arange, epsf.shape[::-1])) - - X = X / epsf.oversampling[0] - epsf.x_origin - Y = Y / epsf.oversampling[1] - epsf.y_origin +class FlowsEPSFBuilder(photutils.psf.EPSFBuilder): + def _create_initial_epsf(self, stars): + epsf = super()._create_initial_epsf(stars) + epsf.origin = None - self._epsf_xy_grid = X, Y + X, Y = np.meshgrid(*map(np.arange, epsf.shape[::-1])) - return epsf + X = X / epsf.oversampling[0] - epsf.x_origin + Y = Y / epsf.oversampling[1] - epsf.y_origin - def _resample_residual(self, star, epsf): + self._epsf_xy_grid = X, Y - #max_dist = .5 / np.sqrt(np.sum(np.power(epsf.oversampling, 2))) + return epsf - #star_points = list(zip(star._xidx_centered, star._yidx_centered)) - #epsf_points = list(zip(*map(np.ravel, self._epsf_xy_grid))) + def _resample_residual(self, star, epsf): + # max_dist = .5 / np.sqrt(np.sum(np.power(epsf.oversampling, 2))) - #star_tree = cKDTree(star_points) - #dd, ii = star_tree.query(epsf_points, distance_upper_bound=max_dist) - #mask = np.isfinite(dd) + # star_points = list(zip(star._xidx_centered, star._yidx_centered)) + # epsf_points 
= list(zip(*map(np.ravel, self._epsf_xy_grid))) - #star_data = np.full_like(epsf.data, np.nan) - #star_data.ravel()[mask] = star._data_values_normalized[ii[mask]] + # star_tree = cKDTree(star_points) + # dd, ii = star_tree.query(epsf_points, distance_upper_bound=max_dist) + # mask = np.isfinite(dd) - star_points = list(zip(star._xidx_centered, star._yidx_centered)) - star_data = griddata(star_points, star._data_values_normalized, - self._epsf_xy_grid) + # star_data = np.full_like(epsf.data, np.nan) + # star_data.ravel()[mask] = star._data_values_normalized[ii[mask]] - return star_data - epsf._data + star_points = list(zip(star._xidx_centered, star._yidx_centered)) + star_data = griddata(star_points, star._data_values_normalized, self._epsf_xy_grid) - def __call__(self, *args, **kwargs): + return star_data - epsf._data - t0 = time.time() + def __call__(self, *args, **kwargs): + t0 = time.time() - epsf, stars = super().__call__(*args, **kwargs) + epsf, stars = super().__call__(*args, **kwargs) - epsf.fit_info = dict( - n_iter=len(self._epsf), - max_iters=self.maxiters, - time=time.time() - t0, - ) + epsf.fit_info = dict(n_iter=len(self._epsf), max_iters=self.maxiters, time=time.time() - t0, ) - return epsf, stars + return epsf, stars diff --git a/flows/load_image.py b/flows/load_image.py index 3693582..ab06bc8 100644 --- a/flows/load_image.py +++ b/flows/load_image.py @@ -16,380 +16,326 @@ from astropy.wcs import WCS, FITSFixedWarning from . import api -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def edge_mask(img, value=0): - """ - Create boolean mask of given value near edge of image. - - Parameters: - img (ndarray): Image of - value (float): Value to detect near edge. Default=0. - - Returns: - ndarray: Pixel mask with given values on the edge of image. - - .. codeauthor:: Rasmus Handberg - """ - - mask1 = (img == value) - mask = np.zeros_like(img, dtype='bool') - - # Mask entire rows and columns which are only the value: - mask[np.all(mask1, axis=1), :] = True - mask[:, np.all(mask1, axis=0)] = True - - # Detect "uneven" edges column-wise in image: - a = np.argmin(mask1, axis=0) - b = np.argmin(np.flipud(mask1), axis=0) - for col in range(img.shape[1]): - if mask1[0, col]: - mask[:a[col], col] = True - if mask1[-1, col]: - mask[-b[col]:, col] = True - - # Detect "uneven" edges row-wise in image: - a = np.argmin(mask1, axis=1) - b = np.argmin(np.fliplr(mask1), axis=1) - for row in range(img.shape[0]): - if mask1[row, 0]: - mask[row, :a[row]] = True - if mask1[row, -1]: - mask[row, -b[row]:] = True - - return mask - -#-------------------------------------------------------------------------------------------------- + """ + Create boolean mask of given value near edge of image. + + Parameters: + img (ndarray): Image of + value (float): Value to detect near edge. Default=0. + + Returns: + ndarray: Pixel mask with given values on the edge of image. + + .. 
codeauthor:: Rasmus Handberg + """ + + mask1 = (img == value) + mask = np.zeros_like(img, dtype='bool') + + # Mask entire rows and columns which are only the value: + mask[np.all(mask1, axis=1), :] = True + mask[:, np.all(mask1, axis=0)] = True + + # Detect "uneven" edges column-wise in image: + a = np.argmin(mask1, axis=0) + b = np.argmin(np.flipud(mask1), axis=0) + for col in range(img.shape[1]): + if mask1[0, col]: + mask[:a[col], col] = True + if mask1[-1, col]: + mask[-b[col]:, col] = True + + # Detect "uneven" edges row-wise in image: + a = np.argmin(mask1, axis=1) + b = np.argmin(np.fliplr(mask1), axis=1) + for row in range(img.shape[0]): + if mask1[row, 0]: + mask[row, :a[row]] = True + if mask1[row, -1]: + mask[row, -b[row]:] = True + + return mask + + +# -------------------------------------------------------------------------------------------------- def load_image(FILENAME, target_coord=None): - """ - Load FITS image. - - Parameters: - FILENAME (str): Path to FITS file to be loaded. - target_coord (:class:`astropy.coordinates.SkyCoord`): Coordinates of target. - Only used for HAWKI images to determine which image extension to load, - for all other images it is ignored. - - Returns: - object: Image constainer. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - # Get image and WCS, find stars, remove galaxies - image = type('image', (object,), dict()) # image container - - # get image and wcs solution - with fits.open(FILENAME, mode='readonly') as hdul: - - hdr = hdul[0].header - image.header = hdr - origin = hdr.get('ORIGIN', '') - telescope = hdr.get('TELESCOP', '') - instrument = hdr.get('INSTRUME', '') - - # Load image data: - image.image = np.asarray(hdul[0].data, dtype='float64') - image.shape = image.image.shape - - # Load image mask: - if origin == 'LCOGT': - if 'BPM' in hdul: - image.mask = np.asarray(hdul['BPM'].data, dtype='bool') - else: - logger.warning('LCOGT image does not contain bad pixel map. Not applying mask.') - image.mask = np.zeros_like(image.image, dtype='bool') - else: - image.mask = np.zeros_like(image.image, dtype='bool') - - image.mask |= ~np.isfinite(image.image) - - # World Coordinate System: - with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=FITSFixedWarning) - image.wcs = WCS(header=hdr, relax=True) - - # Values which will be filled out below, depending on the instrument: - image.exptime = hdr.get('EXPTIME', None) # Exposure time * u.second - image.peakmax = None # Maximum value above which data is not to be trusted - - # Timestamp: - if origin == 'LCOGT': - sites = api.sites.get_all_sites() - site_keywords = {s['site_keyword']: s for s in sites} - image.site = site_keywords.get(hdr['SITE'], None) - - observatory = coords.EarthLocation.from_geodetic(lat=hdr['LATITUDE'], lon=hdr['LONGITUD'], height=hdr['HEIGHT']) - image.obstime = Time(hdr['DATE-OBS'], format='isot', scale='utc', location=observatory) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - - image.photfilter = { - 'zs': 'zp' - }.get(hdr['FILTER'], hdr['FILTER']) - - # Get non-linear limit - # TODO: Use actual or some fraction of the non-linearity limit - #image.peakmax = hdr.get('MAXLIN') # Presumed non-linearity limit from header - image.peakmax = 60000 # From experience, this one is better. 
- - elif origin == 'ESO-PARANAL' and telescope == 'ESO-VLT-U4' and instrument == 'HAWKI' and hdr.get('PRODCATG') == 'SCIENCE.MEFIMAGE': - image.site = api.get_site(2) # Hard-coded the siteid for ESO Paranal, VLT, UT4 - image.obstime = Time(hdr['DATE-OBS'], format='isot', scale='utc', location=image.site['EarthLocation']) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - image.photfilter = hdr['FILTER'] - - # For HAWKI multi-extension images we search the extensions for which one contains - # the target, and overwrites the image data with that: - if target_coord is None: - raise ValueError("TARGET_COORD is needed for HAWKI images to find the correct extension") - target_radec = [[target_coord.icrs.ra.deg, target_coord.icrs.dec.deg]] - for k in range(1, 5): - w = WCS(header=hdul[k].header, relax=True) - s = [hdul[k].header['NAXIS2'], hdul[k].header['NAXIS1']] - pix = w.all_world2pix(target_radec, 0).flatten() - if pix[0] >= -0.5 and pix[1] >= -0.5 and pix[0] <= s[1]-0.5 and pix[1] <= s[0]-0.5: - image.image = np.asarray(hdul[k].data, dtype='float64') - image.shape = image.image.shape - image.wcs = w - image.mask = ~np.isfinite(image.image) - break - else: - raise RuntimeError("Could not find image extension that target is on") - - elif telescope == 'NOT' and instrument in ('ALFOSC FASU', 'ALFOSC_FASU') and hdr.get('OBS_MODE', '').lower() == 'imaging': - image.site = api.get_site(5) # Hard-coded the siteid for NOT - image.obstime = Time(hdr['DATE-AVG'], format='isot', scale='utc', location=image.site['EarthLocation']) - - # Sometimes data from NOT does not have the FILTER keyword, - # in which case we have to try to figure out which filter - # was used based on some of the other headers: - if 'FILTER' in hdr: - image.photfilter = { - 'B Bes': 'B', - 'V Bes': 'V', - 'R Bes': 'R', - 'g SDSS': 'gp', - 'r SDSS': 'rp', - 'i SDSS': 'ip', - 'i int': 'ip', # Interference filter - 'u SDSS': 'up', - 'z SDSS': 'zp' - }.get(hdr['FILTER'].replace('_', ' '), hdr['FILTER']) - else: - filters_used = [] - for check_headers in ('ALFLTNM', 'FAFLTNM', 'FBFLTNM'): - if hdr.get(check_headers) and hdr.get(check_headers).strip().lower() != 'open': - filters_used.append(hdr.get(check_headers).strip()) - if len(filters_used) == 1: - image.photfilter = { - 'B_Bes 440_100': 'B', - 'V_Bes 530_80': 'V', - 'R_Bes 650_130': 'R', - "g'_SDSS 480_145": 'gp', - "r'_SDSS 618_148": 'rp', - "i'_SDSS 771_171": 'ip', - 'i_int 797_157': 'ip', # Interference filter - "z'_SDSS 832_LP": 'zp' - }.get(filters_used[0].replace(' ', ' '), filters_used[0]) - else: - raise RuntimeError("Could not determine filter used.") - - # Get non-linear limit - # Obtained from http://www.not.iac.es/instruments/detectors/CCD14/LED-linearity/20181026-200-1x1.pdf - # TODO: grab these from a table for all detector setups of ALFOSC - image.peakmax = 80000 # For ALFOSC D, 1x1, 200; the standard for SNe. - - elif telescope == 'NOT' and instrument == 'NOTCAM' and hdr.get('OBS_MODE', '').lower() == 'imaging': - image.site = api.get_site(5) # Hard-coded the siteid for NOT - image.obstime = Time(hdr['DATE-AVG'], format='isot', scale='utc', location=image.site['EarthLocation']) - - # Does NOTCAM data sometimes contain a FILTER header? 
- # if not we have to try to figure out which filter - # was used based on some of the other headers: - if 'FILTER' in hdr: - raise RuntimeError("NOTCAM: Filter keyword defined") - filters_used = [] - for check_headers in ('NCFLTNM1', 'NCFLTNM2'): - if hdr.get(check_headers) and hdr.get(check_headers).strip().lower() != 'open': - filters_used.append(hdr.get(check_headers).strip()) - if len(filters_used) == 1: - image.photfilter = { - 'Ks': 'K' - }.get(filters_used[0], filters_used[0]) - else: - raise RuntimeError("Could not determine filter used.") - - # Mask out "halo" of pixels with zero value along edge of image: - image.mask |= edge_mask(image.image, value=0) - - elif hdr.get('FPA.TELESCOPE') == 'PS1' and hdr.get('FPA.INSTRUMENT') == 'GPC1': - image.site = api.get_site(6) # Hard-coded the siteid for Pan-STARRS1 - image.obstime = Time(hdr['MJD-OBS'], format='mjd', scale='utc', location=image.site['EarthLocation']) - - image.photfilter = { - 'g.00000': 'gp', - 'r.00000': 'rp', - 'i.00000': 'ip', - 'z.00000': 'zp' - }.get(hdr['FPA.FILTER'], hdr['FPA.FILTER']) - - elif telescope == 'Liverpool Telescope': - # Liverpool telescope - image.site = api.get_site(8) # Hard-coded the siteid for Liverpool Telescope - image.obstime = Time(hdr['DATE-OBS'], format='isot', scale='utc', location=image.site['EarthLocation']) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - image.photfilter = { - 'Bessel-B': 'B', - 'Bessell-B': 'B', - 'Bessel-V': 'V', - 'Bessell-V': 'V', - 'SDSS-U': 'up', - 'SDSS-G': 'gp', - 'SDSS-R': 'rp', - 'SDSS-I': 'ip', - 'SDSS-Z': 'zp' - }.get(hdr['FILTER1'], hdr['FILTER1']) - - elif telescope == 'CA 3.5m' and instrument == 'Omega2000': - # Calar Alto 3.5m (Omege2000) - image.site = api.get_site(9) # Hard-coded the siteid for Calar Alto 3.5m - image.obstime = Time(hdr['MJD-OBS'], format='mjd', scale='utc', location=image.site['EarthLocation']) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - image.photfilter = hdr['FILTER'] - - elif telescope == 'SWO' and hdr.get('SITENAME') == 'LCO': - image.site = api.get_site(10) # Hard-coded the siteid for Swope, Las Campanas Observatory - image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation']) - image.photfilter = { - 'u': 'up', - 'g': 'gp', - 'r': 'rp', - 'i': 'ip', - }.get(hdr['FILTER'], hdr['FILTER']) - - elif telescope == 'DUP' and hdr.get('SITENAME') == 'LCO' and instrument == 'Direct/SITe2K-1': - image.site = api.get_site(14) # Hard-coded the siteid for Du Pont, Las Campanas Observatory - image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation']) - image.photfilter = { - 'u': 'up', - 'g': 'gp', - 'r': 'rp', - 'i': 'ip', - }.get(hdr['FILTER'], hdr['FILTER']) - - elif telescope == 'DUP' and instrument == 'RetroCam': - image.site = api.get_site(16) # Hard-coded the siteid for Du Pont, Las Campanas Observatory - image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation']) - image.photfilter = { - 'Yc': 'Y', - 'Hc': 'H', - 'Jo': 'J', - }.get(hdr['FILTER'], hdr['FILTER']) - - elif telescope == 'Baade' and hdr.get('SITENAME') == 'LCO' and instrument == 'FourStar': - image.site = api.get_site(11) # Hard-coded the siteid for Swope, Las Campanas Observatory - image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation']) - image.photfilter = { - 'Ks': 'K', - 'J1': 'Y', - }.get(hdr['FILTER'], hdr['FILTER']) - image.exptime *= 
int(hdr['NCOMBINE']) # EXPTIME is only for a single exposure - - elif instrument == 'SOFI' and telescope in ('ESO-NTT', 'other') and (origin == 'ESO' or origin.startswith('NOAO-IRAF')): - image.site = api.get_site(12) # Hard-coded the siteid for SOFT, ESO NTT - if 'TMID' in hdr: - image.obstime = Time(hdr['TMID'], format='mjd', scale='utc', location=image.site['EarthLocation']) - else: - image.obstime = Time(hdr['MJD-OBS'], format='mjd', scale='utc', location=image.site['EarthLocation']) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - - # Photometric filter: - photfilter_translate = { - 'Ks': 'K' - } - if 'FILTER' in hdr: - image.photfilter = photfilter_translate.get(hdr['FILTER'], hdr['FILTER']) - else: - filters_used = [] - for check_headers in ('ESO INS FILT1 ID', 'ESO INS FILT2 ID'): - if hdr.get(check_headers) and hdr.get(check_headers).strip().lower() != 'open': - filters_used.append(hdr.get(check_headers).strip()) - if len(filters_used) == 1: - image.photfilter = photfilter_translate.get(filters_used[0], filters_used[0]) - else: - raise RuntimeError("Could not determine filter used.") - - # Mask out "halo" of pixels with zero value along edge of image: - image.mask |= edge_mask(image.image, value=0) - - elif telescope == 'ESO-NTT' and instrument == 'EFOSC' and (origin == 'ESO' or origin.startswith('NOAO-IRAF')): - image.site = api.get_site(15) # Hard-coded the siteid for EFOSC, ESO NTT - image.obstime = Time(hdr['DATE-OBS'], format='isot', scale='utc', location=image.site['EarthLocation']) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - image.photfilter = { - 'g782': 'gp', - 'r784': 'rp', - 'i705': 'ip', - 'B639': 'B', - 'V641': 'V' - }.get(hdr['FILTER'], hdr['FILTER']) - - elif telescope == 'SAI-2.5' and instrument == 'ASTRONIRCAM': - image.site = api.get_site(13) # Hard-coded the siteid for Caucasus Mountain Observatory - if 'MIDPOINT' in hdr: - image.obstime = Time(hdr['MIDPOINT'], format='isot', scale='utc', location=image.site['EarthLocation']) - else: - image.obstime = Time(hdr['MJD-AVG'], format='mjd', scale='utc', location=image.site['EarthLocation']) - image.photfilter = { - 'H_Open': 'H', - 'K_Open': 'K', - }.get(hdr['FILTER'], hdr['FILTER']) - image.exptime = hdr.get('FULL_EXP', image.exptime) - - elif instrument == 'OMEGACAM' and (origin == 'ESO' or origin.startswith('NOAO-IRAF')): - image.site = api.get_site(18) # Hard-coded the siteid for ESO VLT Survey telescope - image.obstime = Time(hdr['MJD-OBS'], format='mjd', scale='utc', location=image.site['EarthLocation']) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - image.photfilter = { - 'i_SDSS': 'ip' - }.get(hdr['ESO INS FILT1 NAME'], hdr['ESO INS FILT1 NAME']) - - elif instrument == 'ANDICAM-CCD' and hdr.get('OBSERVAT') == 'CTIO': - image.site = api.get_site(20) # Hard-coded the siteid for ANDICAM at Cerro Tololo Interamerican Observatory (CTIO) - image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation']) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - image.photfilter = hdr['CCDFLTID'] - - elif telescope == '1.3m PAIRITEL' and instrument == '2MASS Survey cam': - image.site = api.get_site(21) # Hard-coded the siteid for Peters Automated InfraRed Imaging TELescope - time_start = Time(hdr['STRT_CPU'], format='iso', scale='utc', location=image.site['EarthLocation']) - time_stop = Time(hdr['STOP_CPU'], format='iso', scale='utc', 
location=image.site['EarthLocation']) - image.obstime = time_start + 0.5*(time_stop - time_start) - image.photfilter = { - 'j': 'J', - 'h': 'H', - 'k': 'K', - }.get(hdr['FILTER'], hdr['FILTER']) - - # Mask out "halo" of pixels with zero value along edge of image: - image.mask |= edge_mask(image.image, value=0) - - elif (origin == 'OAdM' or origin.startswith('NOAO-IRAF')) and telescope == 'TJO' and instrument in ('MEIA3', 'MEIA2'): - image.site = api.get_site(22) # Hard-coded the siteid for Telescopi Joan Oró (TJO) at Observatori Astronòmic del Montsec - image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation']) - image.obstime += 0.5*image.exptime * u.second # Make time centre of exposure - image.photfilter = hdr['FILTER'] - - else: - raise RuntimeError("Could not determine origin of image") - - # Sanity checks: - if image.exptime is None: - raise ValueError("Image exposure time could not be extracted") - - # Create masked version of image: - image.image[image.mask] = np.NaN - image.clean = np.ma.masked_array(data=image.image, mask=image.mask, copy=False) - - return image + """ + Load FITS image. + + Parameters: + FILENAME (str): Path to FITS file to be loaded. + target_coord (:class:`astropy.coordinates.SkyCoord`): Coordinates of target. + Only used for HAWKI images to determine which image extension to load, + for all other images it is ignored. + + Returns: + object: Image constainer. + + .. codeauthor:: Rasmus Handberg + """ + + logger = logging.getLogger(__name__) + + # Get image and WCS, find stars, remove galaxies + image = type('image', (object,), dict()) # image container + + # get image and wcs solution + with fits.open(FILENAME, mode='readonly') as hdul: + + hdr = hdul[0].header + image.header = hdr + origin = hdr.get('ORIGIN', '') + telescope = hdr.get('TELESCOP', '') + instrument = hdr.get('INSTRUME', '') + + # Load image data: + image.image = np.asarray(hdul[0].data, dtype='float64') + image.shape = image.image.shape + + # Load image mask: + if origin == 'LCOGT': + if 'BPM' in hdul: + image.mask = np.asarray(hdul['BPM'].data, dtype='bool') + else: + logger.warning('LCOGT image does not contain bad pixel map. Not applying mask.') + image.mask = np.zeros_like(image.image, dtype='bool') + else: + image.mask = np.zeros_like(image.image, dtype='bool') + + image.mask |= ~np.isfinite(image.image) + + # World Coordinate System: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=FITSFixedWarning) + image.wcs = WCS(header=hdr, relax=True) + + # Values which will be filled out below, depending on the instrument: + image.exptime = hdr.get('EXPTIME', None) # Exposure time * u.second + image.peakmax = None # Maximum value above which data is not to be trusted + + # Timestamp: + if origin == 'LCOGT': + sites = api.sites.get_all_sites() + site_keywords = {s['site_keyword']: s for s in sites} + image.site = site_keywords.get(hdr['SITE'], None) + + observatory = coords.EarthLocation.from_geodetic(lat=hdr['LATITUDE'], lon=hdr['LONGITUD'], + height=hdr['HEIGHT']) + image.obstime = Time(hdr['DATE-OBS'], format='isot', scale='utc', location=observatory) + image.obstime += 0.5 * image.exptime * u.second # Make time centre of exposure + + image.photfilter = {'zs': 'zp'}.get(hdr['FILTER'], hdr['FILTER']) + + # Get non-linear limit + # TODO: Use actual or some fraction of the non-linearity limit + # image.peakmax = hdr.get('MAXLIN') # Presumed non-linearity limit from header + image.peakmax = 60000 # From experience, this one is better. 
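+
+            # Hedged aside, not from the original file: DATE-OBS marks the start of
+            # the exposure, so half the exposure time is added to get the mid-exposure
+            # timestamp. The same arithmetic in isolation:
+            #
+            #     >>> from astropy.time import Time
+            #     >>> import astropy.units as u
+            #     >>> t = Time('2022-02-14T21:30:27', format='isot', scale='utc')
+            #     >>> (t + 0.5 * 60.0 * u.second).isot  # 60 s exposure: centre is 30 s later
+            #     '2022-02-14T21:30:57.000'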
+
+        elif origin == 'ESO-PARANAL' and telescope == 'ESO-VLT-U4' and instrument == 'HAWKI' and hdr.get(
+                'PRODCATG') == 'SCIENCE.MEFIMAGE':
+            image.site = api.get_site(2)  # Hard-coded the siteid for ESO Paranal, VLT, UT4
+            image.obstime = Time(hdr['DATE-OBS'], format='isot', scale='utc', location=image.site['EarthLocation'])
+            image.obstime += 0.5 * image.exptime * u.second  # Make time centre of exposure
+            image.photfilter = hdr['FILTER']
+
+            # For HAWKI multi-extension images we search for the extension that contains
+            # the target and overwrite the image data with it:
+            if target_coord is None:
+                raise ValueError("TARGET_COORD is needed for HAWKI images to find the correct extension")
+            target_radec = [[target_coord.icrs.ra.deg, target_coord.icrs.dec.deg]]
+            for k in range(1, 5):
+                w = WCS(header=hdul[k].header, relax=True)
+                s = [hdul[k].header['NAXIS2'], hdul[k].header['NAXIS1']]
+                pix = w.all_world2pix(target_radec, 0).flatten()
+                if pix[0] >= -0.5 and pix[1] >= -0.5 and pix[0] <= s[1] - 0.5 and pix[1] <= s[0] - 0.5:
+                    image.image = np.asarray(hdul[k].data, dtype='float64')
+                    image.shape = image.image.shape
+                    image.wcs = w
+                    image.mask = ~np.isfinite(image.image)
+                    break
+            else:
+                raise RuntimeError("Could not find image extension that target is on")
+
+        elif telescope == 'NOT' and instrument in ('ALFOSC FASU', 'ALFOSC_FASU') and hdr.get('OBS_MODE',
+                                                                                             '').lower() == 'imaging':
+            image.site = api.get_site(5)  # Hard-coded the siteid for NOT
+            image.obstime = Time(hdr['DATE-AVG'], format='isot', scale='utc', location=image.site['EarthLocation'])
+
+            # Sometimes data from NOT does not have the FILTER keyword,
+            # in which case we have to try to figure out which filter
+            # was used based on some of the other headers:
+            if 'FILTER' in hdr:
+                image.photfilter = {'B Bes': 'B', 'V Bes': 'V', 'R Bes': 'R', 'g SDSS': 'gp', 'r SDSS': 'rp',
+                                    'i SDSS': 'ip', 'i int': 'ip',  # Interference filter
+                                    'u SDSS': 'up', 'z SDSS': 'zp'}.get(hdr['FILTER'].replace('_', ' '), hdr['FILTER'])
+            else:
+                filters_used = []
+                for check_headers in ('ALFLTNM', 'FAFLTNM', 'FBFLTNM'):
+                    if hdr.get(check_headers) and hdr.get(check_headers).strip().lower() != 'open':
+                        filters_used.append(hdr.get(check_headers).strip())
+                if len(filters_used) == 1:
+                    image.photfilter = {'B_Bes 440_100': 'B', 'V_Bes 530_80': 'V', 'R_Bes 650_130': 'R',
+                                        "g'_SDSS 480_145": 'gp', "r'_SDSS 618_148": 'rp', "i'_SDSS 771_171": 'ip',
+                                        'i_int 797_157': 'ip',  # Interference filter
+                                        "z'_SDSS 832_LP": 'zp'}.get(filters_used[0].replace('  ', ' '), filters_used[0])
+                else:
+                    raise RuntimeError("Could not determine filter used.")
+
+            # Get non-linear limit
+            # Obtained from http://www.not.iac.es/instruments/detectors/CCD14/LED-linearity/20181026-200-1x1.pdf
+            # TODO: grab these from a table for all detector setups of ALFOSC
+            image.peakmax = 80000  # For ALFOSC D, 1x1, 200; the standard for SNe.
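+
+            # Hedged sketch, not part of the original file: the pattern above
+            # (collect every non-'open' filter-wheel header, then require exactly one
+            # match) could be factored into a helper; the names here are invented:
+            #
+            #     def resolve_filter(hdr, keys, translate):
+            #         used = [hdr[k].strip() for k in keys
+            #                 if hdr.get(k) and hdr.get(k).strip().lower() != 'open']
+            #         if len(used) != 1:
+            #             raise RuntimeError("Could not determine filter used.")
+            #         return translate.get(used[0], used[0])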
+
+        elif telescope == 'NOT' and instrument == 'NOTCAM' and hdr.get('OBS_MODE', '').lower() == 'imaging':
+            image.site = api.get_site(5)  # Hard-coded the siteid for NOT
+            image.obstime = Time(hdr['DATE-AVG'], format='isot', scale='utc', location=image.site['EarthLocation'])
+
+            # Does NOTCAM data sometimes contain a FILTER header?
+            # if not we have to try to figure out which filter
+            # was used based on some of the other headers:
+            if 'FILTER' in hdr:
+                raise RuntimeError("NOTCAM: Filter keyword defined")
+            filters_used = []
+            for check_headers in ('NCFLTNM1', 'NCFLTNM2'):
+                if hdr.get(check_headers) and hdr.get(check_headers).strip().lower() != 'open':
+                    filters_used.append(hdr.get(check_headers).strip())
+            if len(filters_used) == 1:
+                image.photfilter = {'Ks': 'K'}.get(filters_used[0], filters_used[0])
+            else:
+                raise RuntimeError("Could not determine filter used.")
+
+            # Mask out "halo" of pixels with zero value along edge of image:
+            image.mask |= edge_mask(image.image, value=0)
+
+        elif hdr.get('FPA.TELESCOPE') == 'PS1' and hdr.get('FPA.INSTRUMENT') == 'GPC1':
+            image.site = api.get_site(6)  # Hard-coded the siteid for Pan-STARRS1
+            image.obstime = Time(hdr['MJD-OBS'], format='mjd', scale='utc', location=image.site['EarthLocation'])
+
+            image.photfilter = {'g.00000': 'gp', 'r.00000': 'rp', 'i.00000': 'ip', 'z.00000': 'zp'}.get(
+                hdr['FPA.FILTER'], hdr['FPA.FILTER'])
+
+        elif telescope == 'Liverpool Telescope':
+            # Liverpool telescope
+            image.site = api.get_site(8)  # Hard-coded the siteid for Liverpool Telescope
+            image.obstime = Time(hdr['DATE-OBS'], format='isot', scale='utc', location=image.site['EarthLocation'])
+            image.obstime += 0.5 * image.exptime * u.second  # Make time centre of exposure
+            image.photfilter = {'Bessel-B': 'B', 'Bessell-B': 'B', 'Bessel-V': 'V', 'Bessell-V': 'V', 'SDSS-U': 'up',
+                                'SDSS-G': 'gp', 'SDSS-R': 'rp', 'SDSS-I': 'ip', 'SDSS-Z': 'zp'}.get(hdr['FILTER1'],
+                                                                                                    hdr['FILTER1'])
+
+        elif telescope == 'CA 3.5m' and instrument == 'Omega2000':
+            # Calar Alto 3.5m (Omega2000)
+            image.site = api.get_site(9)  # Hard-coded the siteid for Calar Alto 3.5m
+            image.obstime = Time(hdr['MJD-OBS'], format='mjd', scale='utc', location=image.site['EarthLocation'])
+            image.obstime += 0.5 * image.exptime * u.second  # Make time centre of exposure
+            image.photfilter = hdr['FILTER']
+
+        elif telescope == 'SWO' and hdr.get('SITENAME') == 'LCO':
+            image.site = api.get_site(10)  # Hard-coded the siteid for Swope, Las Campanas Observatory
+            image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation'])
+            image.photfilter = {'u': 'up', 'g': 'gp', 'r': 'rp', 'i': 'ip', }.get(hdr['FILTER'], hdr['FILTER'])
+
+        elif telescope == 'DUP' and hdr.get('SITENAME') == 'LCO' and instrument == 'Direct/SITe2K-1':
+            image.site = api.get_site(14)  # Hard-coded the siteid for Du Pont, Las Campanas Observatory
+            image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation'])
+            image.photfilter = {'u': 'up', 'g': 'gp', 'r': 'rp', 'i': 'ip', }.get(hdr['FILTER'], hdr['FILTER'])
+
+        elif telescope == 'DUP' and instrument == 'RetroCam':
+            image.site = api.get_site(16)  # Hard-coded the siteid for Du Pont, Las Campanas Observatory
+            image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation'])
+            image.photfilter = {'Yc': 'Y', 'Hc': 'H', 'Jo': 'J', }.get(hdr['FILTER'], hdr['FILTER'])
+
+        elif telescope == 'Baade' and hdr.get('SITENAME') == 'LCO' and instrument == 'FourStar':
+            image.site = api.get_site(11)  # Hard-coded the siteid for Baade, Las Campanas Observatory
+            image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation'])
+            image.photfilter = {'Ks': 'K', 'J1': 'Y', }.get(hdr['FILTER'], hdr['FILTER'])
+            image.exptime *= int(hdr['NCOMBINE'])  # EXPTIME is only for a single exposure
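+
+            # Hedged sketch, not part of the original file: the hard-coded
+            # api.get_site(...) calls in these branches could key off a lookup table
+            # instead; the structure below is invented for illustration:
+            #
+            #     SITEIDS = {('SWO', None): 10, ('Baade', 'FourStar'): 11,
+            #                ('DUP', 'RetroCam'): 16}
+            #     siteid = SITEIDS.get((telescope, instrument),
+            #                          SITEIDS.get((telescope, None)))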
+        elif instrument == 'SOFI' and telescope in ('ESO-NTT', 'other') and (
+                origin == 'ESO' or origin.startswith('NOAO-IRAF')):
+            image.site = api.get_site(12)  # Hard-coded the siteid for SOFI, ESO NTT
+            if 'TMID' in hdr:
+                image.obstime = Time(hdr['TMID'], format='mjd', scale='utc', location=image.site['EarthLocation'])
+            else:
+                image.obstime = Time(hdr['MJD-OBS'], format='mjd', scale='utc', location=image.site['EarthLocation'])
+            image.obstime += 0.5 * image.exptime * u.second  # Make time centre of exposure
+
+            # Photometric filter:
+            photfilter_translate = {'Ks': 'K'}
+            if 'FILTER' in hdr:
+                image.photfilter = photfilter_translate.get(hdr['FILTER'], hdr['FILTER'])
+            else:
+                filters_used = []
+                for check_headers in ('ESO INS FILT1 ID', 'ESO INS FILT2 ID'):
+                    if hdr.get(check_headers) and hdr.get(check_headers).strip().lower() != 'open':
+                        filters_used.append(hdr.get(check_headers).strip())
+                if len(filters_used) == 1:
+                    image.photfilter = photfilter_translate.get(filters_used[0], filters_used[0])
+                else:
+                    raise RuntimeError("Could not determine filter used.")
+
+            # Mask out "halo" of pixels with zero value along edge of image:
+            image.mask |= edge_mask(image.image, value=0)
+
+        elif telescope == 'ESO-NTT' and instrument == 'EFOSC' and (origin == 'ESO' or origin.startswith('NOAO-IRAF')):
+            image.site = api.get_site(15)  # Hard-coded the siteid for EFOSC, ESO NTT
+            image.obstime = Time(hdr['DATE-OBS'], format='isot', scale='utc', location=image.site['EarthLocation'])
+            image.obstime += 0.5 * image.exptime * u.second  # Make time centre of exposure
+            image.photfilter = {'g782': 'gp', 'r784': 'rp', 'i705': 'ip', 'B639': 'B', 'V641': 'V'}.get(hdr['FILTER'],
+                                                                                                        hdr['FILTER'])
+
+        elif telescope == 'SAI-2.5' and instrument == 'ASTRONIRCAM':
+            image.site = api.get_site(13)  # Hard-coded the siteid for Caucasus Mountain Observatory
+            if 'MIDPOINT' in hdr:
+                image.obstime = Time(hdr['MIDPOINT'], format='isot', scale='utc', location=image.site['EarthLocation'])
+            else:
+                image.obstime = Time(hdr['MJD-AVG'], format='mjd', scale='utc', location=image.site['EarthLocation'])
+            image.photfilter = {'H_Open': 'H', 'K_Open': 'K', }.get(hdr['FILTER'], hdr['FILTER'])
+            image.exptime = hdr.get('FULL_EXP', image.exptime)
+
+        elif instrument == 'OMEGACAM' and (origin == 'ESO' or origin.startswith('NOAO-IRAF')):
+            image.site = api.get_site(18)  # Hard-coded the siteid for ESO VLT Survey telescope
+            image.obstime = Time(hdr['MJD-OBS'], format='mjd', scale='utc', location=image.site['EarthLocation'])
+            image.obstime += 0.5 * image.exptime * u.second  # Make time centre of exposure
+            image.photfilter = {'i_SDSS': 'ip'}.get(hdr['ESO INS FILT1 NAME'], hdr['ESO INS FILT1 NAME'])
+
+        elif instrument == 'ANDICAM-CCD' and hdr.get('OBSERVAT') == 'CTIO':
+            image.site = api.get_site(
+                20)  # Hard-coded the siteid for ANDICAM at Cerro Tololo Interamerican Observatory (CTIO)
+            image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation'])
+            image.obstime += 0.5 * image.exptime * u.second  # Make time centre of exposure
+            image.photfilter = hdr['CCDFLTID']
+
+        elif telescope == '1.3m PAIRITEL' and instrument == '2MASS Survey cam':
+            image.site = api.get_site(21)  # Hard-coded the siteid for Peters Automated InfraRed Imaging TELescope
+            time_start = Time(hdr['STRT_CPU'], format='iso', scale='utc', location=image.site['EarthLocation'])
+            time_stop = Time(hdr['STOP_CPU'], format='iso', scale='utc', location=image.site['EarthLocation'])
+            image.obstime = time_start + 0.5 * (time_stop - time_start)
+            image.photfilter = {'j': 
'J', 'h': 'H', 'k': 'K', }.get(hdr['FILTER'], hdr['FILTER']) + + # Mask out "halo" of pixels with zero value along edge of image: + image.mask |= edge_mask(image.image, value=0) + + elif (origin == 'OAdM' or origin.startswith('NOAO-IRAF')) and telescope == 'TJO' and instrument in ( + 'MEIA3', 'MEIA2'): + image.site = api.get_site( + 22) # Hard-coded the siteid for Telescopi Joan Oró (TJO) at Observatori Astronòmic del Montsec + image.obstime = Time(hdr['JD'], format='jd', scale='utc', location=image.site['EarthLocation']) + image.obstime += 0.5 * image.exptime * u.second # Make time centre of exposure + image.photfilter = hdr['FILTER'] + + else: + raise RuntimeError("Could not determine origin of image") + + # Sanity checks: + if image.exptime is None: + raise ValueError("Image exposure time could not be extracted") + + # Create masked version of image: + image.image[image.mask] = np.NaN + image.clean = np.ma.masked_array(data=image.image, mask=image.mask, copy=False) + + return image diff --git a/flows/photometry.py b/flows/photometry.py index 3c7378c..f6d89e8 100644 --- a/flows/photometry.py +++ b/flows/photometry.py @@ -28,735 +28,649 @@ import sep warnings.simplefilter('ignore', category=AstropyDeprecationWarning) -from photutils import CircularAperture, CircularAnnulus, aperture_photometry # noqa: E402 -from photutils.psf import EPSFFitter, BasicPSFPhotometry, DAOGroup, extract_stars # noqa: E402 -from photutils import Background2D, SExtractorBackground, MedianBackground # noqa: E402 -from photutils.utils import calc_total_error # noqa: E402 - -from . import api # noqa: E402 -from . import reference_cleaning as refclean # noqa: E402 -from .config import load_config # noqa: E402 -from .plots import plt, plot_image # noqa: E402 -from .version import get_version # noqa: E402 -from .load_image import load_image # noqa: E402 -from .run_imagematch import run_imagematch # noqa: E402 -from .zeropoint import bootstrap_outlier, sigma_from_Chauvenet # noqa: E402 -from .coordinatematch import CoordinateMatch, WCS2 # noqa: E402 -from .epsfbuilder import FlowsEPSFBuilder # noqa: E402 +from photutils import CircularAperture, CircularAnnulus, aperture_photometry # noqa: E402 +from photutils.psf import EPSFFitter, BasicPSFPhotometry, DAOGroup, extract_stars # noqa: E402 +from photutils import Background2D, SExtractorBackground, MedianBackground # noqa: E402 +from photutils.utils import calc_total_error # noqa: E402 + +from . import api # noqa: E402 +from . import reference_cleaning as refclean # noqa: E402 +from .config import load_config # noqa: E402 +from .plots import plt, plot_image # noqa: E402 +from .version import get_version # noqa: E402 +from .load_image import load_image # noqa: E402 +from .run_imagematch import run_imagematch # noqa: E402 +from .zeropoint import bootstrap_outlier, sigma_from_Chauvenet # noqa: E402 +from .coordinatematch import CoordinateMatch, WCS2 # noqa: E402 +from .epsfbuilder import FlowsEPSFBuilder # noqa: E402 __version__ = get_version(pep440=False) -#-------------------------------------------------------------------------------------------------- -def photometry(fileid, output_folder=None, attempt_imagematch=True, keep_diff_fixed=False, - cm_timeout=None): - """ - Run photometry. - - Parameters: - fileid (int): File ID to process. - output_folder (str, optional): Path to directory where output should be placed. 
- attempt_imagematch (bool, optional): If no subtracted image is available, but a - template image is, should we attempt to run ImageMatch using standard settings. - Default=True. - keep_diff_fixed (bool, optional): Allow psf photometry to recenter when - calculating the flux for the difference image. Setting to True can help if diff - image has non-source flux in the region around the SN. - cm_timeout (float, optional): Timeout in seconds for the :class:`CoordinateMatch` algorithm. - - .. codeauthor:: Rasmus Handberg - .. codeauthor:: Emir Karamehmetoglu - .. codeauthor:: Simon Holmbo - """ - - # Settings: - ref_target_dist_limit = 10 * u.arcsec # Reference star must be further than this away to be included - - logger = logging.getLogger(__name__) - tic = default_timer() - - # Use local copy of archive if configured to do so: - config = load_config() - - # Get datafile dict from API: - datafile = api.get_datafile(fileid) - logger.debug("Datafile: %s", datafile) - targetid = datafile['targetid'] - target_name = datafile['target_name'] - photfilter = datafile['photfilter'] - - archive_local = config.get('photometry', 'archive_local', fallback=None) - if archive_local is not None: - datafile['archive_path'] = archive_local - if not os.path.isdir(datafile['archive_path']): - raise FileNotFoundError("ARCHIVE is not available: " + datafile['archive_path']) - - # Get the catalog containing the target and reference stars: - # TODO: Include proper-motion to the time of observation - catalog = api.get_catalog(targetid, output='table') - target = catalog['target'][0] - target_coord = coords.SkyCoord( - ra=target['ra'], - dec=target['decl'], - unit='deg', - frame='icrs') - - # Folder to save output: - if output_folder is None: - output_folder_root = config.get('photometry', 'output', fallback='.') - output_folder = os.path.join(output_folder_root, target_name, f'{fileid:05d}') - logger.info("Placing output in '%s'", output_folder) - os.makedirs(output_folder, exist_ok=True) - - # The paths to the science image: - filepath = os.path.join(datafile['archive_path'], datafile['path']) - - # TODO: Download datafile using API to local drive: - # TODO: Is this a security concern? - # if archive_local: - # api.download_datafile(datafile, archive_local) - - # Translate photometric filter into table column: - ref_filter = { - 'up': 'u_mag', - 'gp': 'g_mag', - 'rp': 'r_mag', - 'ip': 'i_mag', - 'zp': 'z_mag', - 'B': 'B_mag', - 'V': 'V_mag', - 'J': 'J_mag', - 'H': 'H_mag', - 'K': 'K_mag', - }.get(photfilter, None) - - if ref_filter is None: - logger.warning("Could not find filter '%s' in catalogs. 
Using default gp filter.", photfilter) - ref_filter = 'g_mag' - - # Load the image from the FITS file: - logger.info("Load image '%s'", filepath) - image = load_image(filepath, target_coord=target_coord) - - references = catalog['references'] - references.sort(ref_filter) - - # Check that there actually are reference stars in that filter: - if allnan(references[ref_filter]): - raise ValueError("No reference stars found in current photfilter.") - - #============================================================================================== - # BARYCENTRIC CORRECTION OF TIME - #============================================================================================== - - ltt_bary = image.obstime.light_travel_time(target_coord, ephemeris='jpl') - image.obstime = image.obstime.tdb + ltt_bary - - #============================================================================================== - # BACKGROUND ESTIMATION - #============================================================================================== - - fig, ax = plt.subplots(1, 2, figsize=(20, 18)) - plot_image(image.clean, ax=ax[0], scale='log', cbar='right', title='Image') - plot_image(image.mask, ax=ax[1], scale='linear', cbar='right', title='Mask') - fig.savefig(os.path.join(output_folder, 'original.png'), bbox_inches='tight') - plt.close(fig) - - # Estimate image background: - # Not using image.clean here, since we are redefining the mask anyway - background = Background2D(image.clean, (128, 128), - filter_size=(5, 5), - sigma_clip=SigmaClip(sigma=3.0), - bkg_estimator=SExtractorBackground(), - exclude_percentile=50.0) - - # Create background-subtracted image: - image.subclean = image.clean - background.background - - # Plot background estimation: - fig, ax = plt.subplots(1, 3, figsize=(20, 6)) - plot_image(image.clean, ax=ax[0], scale='log', title='Original') - plot_image(background.background, ax=ax[1], scale='log', title='Background') - plot_image(image.subclean, ax=ax[2], scale='log', title='Background subtracted') - fig.savefig(os.path.join(output_folder, 'background.png'), bbox_inches='tight') - plt.close(fig) - - # TODO: Is this correct?! - image.error = calc_total_error(image.clean, background.background_rms, 1.0) - - # Use sep to for soure extraction - sep_background = sep.Background(image.image, mask=image.mask) - objects = sep.extract(image.image - sep_background, - thresh=5., - err=sep_background.globalrms, - mask=image.mask, - deblend_cont=0.1, - minarea=9, - clean_param=2.0) - - # Cleanup large arrays which are no longer needed: - del background, fig, ax, sep_background, ltt_bary - gc.collect() - - #============================================================================================== - # DETECTION OF STARS AND MATCHING WITH CATALOG - #============================================================================================== - - # Account for proper motion: - replace(references['pm_ra'], np.NaN, 0) - replace(references['pm_dec'], np.NaN, 0) - refs_coord = coords.SkyCoord( - ra=references['ra'], - dec=references['decl'], - pm_ra_cosdec=references['pm_ra'], - pm_dec=references['pm_dec'], - unit='deg', - frame='icrs', - obstime=Time(2015.5, format='decimalyear')) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ErfaWarning) - refs_coord = refs_coord.apply_space_motion(new_obstime=image.obstime) - - # TODO: These need to be based on the instrument! 
- radius = 10 - fwhm_guess = 6.0 - fwhm_min = 3.5 - fwhm_max = 18.0 - - # Clean extracted stars - masked_sep_xy, sep_mask, masked_sep_rsqs = refclean.force_reject_g2d( - objects['x'], - objects['y'], - image, - get_fwhm=False, - radius=radius, - fwhm_guess=fwhm_guess, - rsq_min=0.3, - fwhm_max=fwhm_max, - fwhm_min=fwhm_min) - - logger.info("Finding new WCS solution...") - head_wcs = str(WCS2.from_astropy_wcs(image.wcs)) - logger.debug('Head WCS: %s', head_wcs) - - # Solve for new WCS - cm = CoordinateMatch( - xy=list(masked_sep_xy[sep_mask]), - rd=list(zip(refs_coord.ra.deg, refs_coord.dec.deg)), - xy_order=np.argsort( - np.power(masked_sep_xy[sep_mask] - np.array(image.shape[::-1]) / 2, - 2).sum(axis=1)), - rd_order=np.argsort(target_coord.separation(refs_coord)), - xy_nmax=100, - rd_nmax=100, - maximum_angle_distance=0.002) - - # Set timeout par to infinity unless specified. - if cm_timeout is None: - cm_timeout = float('inf') - try: - i_xy, i_rd = map(np.array, zip(*cm(5, 1.5, timeout=cm_timeout))) - except TimeoutError: - logger.warning('TimeoutError: No new WCS solution found') - except StopIteration: - logger.warning('StopIterationError: No new WCS solution found') - else: - logger.info('Found new WCS') - image.wcs = fit_wcs_from_points( - np.array(list(zip(*cm.xy[i_xy]))), - coords.SkyCoord(*map(list, zip(*cm.rd[i_rd])), unit='deg')) - del i_xy, i_rd - - used_wcs = str(WCS2.from_astropy_wcs(image.wcs)) - logger.debug('Used WCS: %s', used_wcs) - - # Calculate pixel-coordinates of references: - xy = image.wcs.all_world2pix(list(zip(refs_coord.ra.deg, refs_coord.dec.deg)), 0) - references['pixel_column'], references['pixel_row'] = x, y = list(map(np.array, zip(*xy))) - - # Clean out the references: - hsize = 10 - clean_references = references[ - (target_coord.separation(refs_coord) > ref_target_dist_limit) - & (x > hsize) & (x < (image.shape[1] - 1 - hsize)) - & (y > hsize) & (y < (image.shape[0] - 1 - hsize))] - - if not clean_references: - raise RuntimeError('No clean references in field') - - # Calculate the targets position in the image: - target_pixel_pos = image.wcs.all_world2pix([(target['ra'], target['decl'])], 0)[0] - - # Clean reference star locations - masked_fwhms, masked_ref_xys, rsq_mask, masked_rsqs = refclean.force_reject_g2d( - clean_references['pixel_column'], - clean_references['pixel_row'], - image, - get_fwhm=True, - radius=radius, - fwhm_guess=fwhm_guess, - fwhm_max=fwhm_max, - fwhm_min=fwhm_min, - rsq_min=0.15) - - # Use R^2 to more robustly determine initial FWHM guess. - # This cleaning is good when we have FEW references. - fwhm, fwhm_clean_references = refclean.clean_with_rsq_and_get_fwhm( - masked_fwhms, - masked_rsqs, - clean_references, - min_fwhm_references=2, - min_references=6, - rsq_min=0.15) - logger.info('Initial FWHM guess is %f pixels', fwhm) - - # Create plot of target and reference star positions from 2D Gaussian fits. 
- fig, ax = plt.subplots(1, 1, figsize=(20, 18)) - plot_image(image.subclean, ax=ax, scale='log', cbar='right', title=target_name) - ax.scatter(fwhm_clean_references['pixel_column'], fwhm_clean_references['pixel_row'], c='r', marker='o', alpha=0.3) - ax.scatter(masked_sep_xy[:, 0], masked_sep_xy[:, 1], marker='s', alpha=1.0, edgecolors='green', facecolors='none') - ax.scatter(target_pixel_pos[0], target_pixel_pos[1], marker='+', s=20, c='r') - fig.savefig(os.path.join(output_folder, 'positions_g2d.png'), bbox_inches='tight') - plt.close(fig) - - # Final clean of wcs corrected references - logger.info("Number of references before final cleaning: %d", len(clean_references)) - logger.debug('Masked R^2 values: %s', masked_rsqs[rsq_mask]) - references = refclean.get_clean_references(clean_references, masked_rsqs, rsq_ideal=0.8) - logger.info("Number of references after final cleaning: %d", len(references)) - - # Create plot of target and reference star positions: - fig, ax = plt.subplots(1, 1, figsize=(20, 18)) - plot_image(image.subclean, ax=ax, scale='log', cbar='right', title=target_name) - ax.scatter(references['pixel_column'], references['pixel_row'], c='r', marker='o', alpha=0.6) - ax.scatter(masked_sep_xy[:, 0], masked_sep_xy[:, 1], marker='s', alpha=0.6, edgecolors='green', facecolors='none') - ax.scatter(target_pixel_pos[0], target_pixel_pos[1], marker='+', s=20, c='r') - fig.savefig(os.path.join(output_folder, 'positions.png'), bbox_inches='tight') - plt.close(fig) - - # Cleanup large arrays which are no longer needed: - del fig, ax, cm - gc.collect() - - #============================================================================================== - # CREATE EFFECTIVE PSF MODEL - #============================================================================================== - - # Make cutouts of stars using extract_stars: - # Scales with FWHM - size = int(np.round(29 * fwhm / 6)) - size += 0 if size % 2 else 1 # Make sure it's a uneven number - size = max(size, 15) # Never go below 15 pixels - - # Extract stars sub-images: - xy = [tuple(masked_ref_xys[clean_references['starid'] == ref['starid']].data[0]) for ref in references] - with warnings.catch_warnings(): - warnings.simplefilter('ignore', AstropyUserWarning) - stars = extract_stars( - NDData(data=image.subclean.data, mask=image.mask), - Table(np.array(xy), names=('x', 'y')), - size=size + 6 # +6 for edge buffer - ) - - logger.info("Number of stars input to ePSF builder: %d", len(stars)) - - # Plot the stars being used for ePSF: - imgnr = 0 - nrows, ncols = 5, 5 - for k in range(int(np.ceil(len(stars) / (nrows * ncols)))): - fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), squeeze=True) - ax = ax.ravel() - for i in range(nrows * ncols): - if imgnr > len(stars) - 1: - ax[i].axis('off') - else: - plot_image(stars[imgnr], ax=ax[i], scale='log', cmap='viridis') # FIXME (no x-ticks) - imgnr += 1 - - fig.savefig(os.path.join(output_folder, f'epsf_stars{k+1:02d}.png'), bbox_inches='tight') - plt.close(fig) - - # Build the ePSF: - epsf, stars = FlowsEPSFBuilder( - oversampling=1, - shape=1 * size, - fitter=EPSFFitter(fit_boxsize=max(int(np.round(1.5 * fwhm)), 5)), - recentering_boxsize=max(int(np.round(2 * fwhm)), 5), - norm_radius=max(fwhm, 5), - maxiters=100, - progress_bar=logger.isEnabledFor(logging.INFO) - )(stars) - logger.info('Built PSF model (%(n_iter)d/%(max_iters)d) in %(time).1f seconds', epsf.fit_info) - - # Store which stars were used in ePSF in the table: - references['used_for_epsf'] = False - 
references['used_for_epsf'][[star.id_label - 1 for star in stars.all_good_stars]] = True - logger.info("Number of stars used for ePSF: %d", np.sum(references['used_for_epsf'])) - - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 15)) - plot_image(epsf.data, ax=ax1, cmap='viridis') - - fwhms = [] - bad_epsf_detected = False - for a, ax in ((0, ax3), (1, ax2)): - # Collapse the PDF along this axis: - profile = epsf.data.sum(axis=a) - itop = profile.argmax() - poffset = profile[itop] / 2 - - # Run a spline through the points, but subtract half of the peak value, and find the roots: - # We have to use a cubic spline, since roots() is not supported for other splines - # for some reason - profile_intp = UnivariateSpline(np.arange(0, len(profile)), profile - poffset, - k=3, s=0, ext=3) - lr = profile_intp.roots() - - # Plot the profile and spline: - x_fine = np.linspace(-0.5, len(profile) - 0.5, 500) - ax.plot(profile, 'k.-') - ax.plot(x_fine, profile_intp(x_fine) + poffset, 'g-') - ax.axvline(itop) - ax.set_xlim(-0.5, len(profile) - 0.5) - - # Do some sanity checks on the ePSF: - # It should pass 50% exactly twice and have the maximum inside that region. - # I.e. it should be a single gaussian-like peak - if len(lr) != 2 or itop < lr[0] or itop > lr[1]: - logger.error("Bad PSF along axis %d", a) - bad_epsf_detected = True - else: - axis_fwhm = lr[1] - lr[0] - fwhms.append(axis_fwhm) - ax.axvspan(lr[0], lr[1], facecolor='g', alpha=0.2) - - # Save the ePSF figure: - ax4.axis('off') - fig.savefig(os.path.join(output_folder, 'epsf.png'), bbox_inches='tight') - plt.close(fig) - - # There was a problem with the ePSF: - if bad_epsf_detected: - raise RuntimeError("Bad ePSF detected.") - - # Let's make the final FWHM the largest one we found: - fwhm = np.max(fwhms) - logger.info("Final FWHM based on ePSF: %f", fwhm) - - # Cleanup large arrays which are no longer needed: - del fig, ax, stars, fwhms, profile_intp - gc.collect() - - #============================================================================================== - # COORDINATES TO DO PHOTOMETRY AT - #============================================================================================== - - coordinates = np.array([[ref['pixel_column'], ref['pixel_row']] for ref in references]) - - # Add the main target position as the first entry for doing photometry directly in the - # science image: - coordinates = np.concatenate(([target_pixel_pos], coordinates), axis=0) - - #============================================================================================== - # APERTURE PHOTOMETRY - #============================================================================================== - - # Define apertures for aperture photometry: - apertures = CircularAperture(coordinates, r=fwhm) - annuli = CircularAnnulus(coordinates, r_in=1.5*fwhm, r_out=2.5*fwhm) - - apphot_tbl = aperture_photometry(image.subclean, [apertures, annuli], - mask=image.mask, error=image.error) - - logger.info('Aperture Photometry Success') - logger.debug("Aperture Photometry Table:\n%s", apphot_tbl) - - #============================================================================================== - # PSF PHOTOMETRY - #============================================================================================== - - # Create photometry object: - photometry_obj = BasicPSFPhotometry( - group_maker=DAOGroup(fwhm), - bkg_estimator=MedianBackground(), - psf_model=epsf, - fitter=fitting.LevMarLSQFitter(), - fitshape=size, - aperture_radius=fwhm) - - psfphot_tbl = 
photometry_obj(image=image.subclean, - init_guesses=Table(coordinates, names=['x_0', 'y_0'])) - - logger.info('PSF Photometry Success') - logger.debug("PSF Photometry Table:\n%s", psfphot_tbl) - - #============================================================================================== - # TEMPLATE SUBTRACTION AND TARGET PHOTOMETRY - #============================================================================================== - - # Find the pixel-scale of the science image: - pixel_area = proj_plane_pixel_area(image.wcs.celestial) - pixel_scale = np.sqrt(pixel_area) * 3600 # arcsec/pixel - # print(image.wcs.celestial.cunit) % Doesn't work? - logger.info("Science image pixel scale: %f", pixel_scale) - - diffimage = None - if datafile.get('diffimg') is not None: - diffimg_path = os.path.join(datafile['archive_path'], datafile['diffimg']['path']) - diffimg = load_image(diffimg_path) - diffimage = diffimg.image - - elif attempt_imagematch and datafile.get('template') is not None: - # Run the template subtraction, and get back - # the science image where the template has been subtracted: - diffimage = run_imagematch(datafile, target, - star_coord=coordinates, - fwhm=fwhm, - pixel_scale=pixel_scale) - - # We have a diff image, so let's do photometry of the target using this: - if diffimage is not None: - # Include mask from original image: - diffimage = np.ma.masked_array(diffimage, image.mask) - - # Create apertures around the target: - apertures = CircularAperture(target_pixel_pos, r=fwhm) - annuli = CircularAnnulus(target_pixel_pos, r_in=1.5*fwhm, r_out=2.5*fwhm) - - # Create two plots of the difference image: - fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(20, 20)) - plot_image(diffimage, ax=ax, cbar='right', title=target_name) - ax.plot(target_pixel_pos[0], target_pixel_pos[1], marker='+', markersize=20, color='r') - fig.savefig(os.path.join(output_folder, 'diffimg.png'), bbox_inches='tight') - apertures.plot(axes=ax, color='r', lw=2) - annuli.plot(axes=ax, color='r', lw=2) - ax.set_xlim(target_pixel_pos[0] - 50, target_pixel_pos[0] + 50) - ax.set_ylim(target_pixel_pos[1] - 50, target_pixel_pos[1] + 50) - fig.savefig(os.path.join(output_folder, 'diffimg_zoom.png'), bbox_inches='tight') - plt.close(fig) - - # Run aperture photometry on subtracted image: - target_apphot_tbl = aperture_photometry(diffimage, [apertures, annuli], - mask=image.mask, - error=image.error) - - # Make target only photometry object if keep_diff_fixed = True - if keep_diff_fixed: - epsf.fixed.update({'x_0': True, 'y_0': True}) - - # TODO: Try iteraratively subtracted photometry - # Create photometry object: - photometry_obj = BasicPSFPhotometry( - group_maker=DAOGroup(0.0001), - bkg_estimator=MedianBackground(), - psf_model=epsf, - fitter=fitting.LevMarLSQFitter(), - fitshape=size, - aperture_radius=fwhm) - - # Run PSF photometry on template subtracted image: - target_psfphot_tbl = photometry_obj(diffimage, - init_guesses=Table(target_pixel_pos, names=['x_0', 'y_0'])) - - # Need to adjust table columns if x_0 and y_0 were fixed - if keep_diff_fixed: - target_psfphot_tbl['x_0_unc'] = 0.0 - target_psfphot_tbl['y_0_unc'] = 0.0 - - # Combine the output tables from the target and the reference stars into one: - apphot_tbl = vstack([target_apphot_tbl, apphot_tbl], join_type='exact') - psfphot_tbl = vstack([target_psfphot_tbl, psfphot_tbl], join_type='exact') - - # Build results table: - tab = references.copy() - - row = { - 'starid': 0, - 'ra': target['ra'], - 'decl': target['decl'], - 'pixel_column': 
target_pixel_pos[0], - 'pixel_row': target_pixel_pos[1], - 'used_for_epsf': False - } - row.update([(k, np.NaN) for k in set(tab.keys()) - set(row) - {'gaia_variability'}]) - tab.insert_row(0, row) - - if diffimage is not None: - row['starid'] = -1 - tab.insert_row(0, row) - - indx_main_target = tab['starid'] <= 0 - - # Subtract background estimated from annuli: - flux_aperture = apphot_tbl['aperture_sum_0'] - (apphot_tbl['aperture_sum_1'] / annuli.area) * apertures.area - flux_aperture_error = np.sqrt(apphot_tbl['aperture_sum_err_0']**2 + (apphot_tbl['aperture_sum_err_1'] / annuli.area * apertures.area)**2) - - # Add table columns with results: - tab['flux_aperture'] = flux_aperture / image.exptime - tab['flux_aperture_error'] = flux_aperture_error / image.exptime - tab['flux_psf'] = psfphot_tbl['flux_fit'] / image.exptime - tab['flux_psf_error'] = psfphot_tbl['flux_unc'] / image.exptime - tab['pixel_column_psf_fit'] = psfphot_tbl['x_fit'] - tab['pixel_row_psf_fit'] = psfphot_tbl['y_fit'] - tab['pixel_column_psf_fit_error'] = psfphot_tbl['x_0_unc'] - tab['pixel_row_psf_fit_error'] = psfphot_tbl['y_0_unc'] - - # Check that we got valid photometry: - if np.any(~np.isfinite(tab[indx_main_target]['flux_psf'])) or np.any(~np.isfinite(tab[indx_main_target]['flux_psf_error'])): - raise RuntimeError("Target magnitude is undefined.") - - #============================================================================================== - # CALIBRATE - #============================================================================================== - - # Convert PSF fluxes to magnitudes: - mag_inst = -2.5 * np.log10(tab['flux_psf']) - mag_inst_err = (2.5 / np.log(10)) * (tab['flux_psf_error'] / tab['flux_psf']) - - # Corresponding magnitudes in catalog: - mag_catalog = tab[ref_filter] - - # Mask out things that should not be used in calibration: - use_for_calibration = np.ones_like(mag_catalog, dtype='bool') - use_for_calibration[indx_main_target] = False # Do not use target for calibration - use_for_calibration[~np.isfinite(mag_inst) | ~np.isfinite(mag_catalog)] = False - - # Just creating some short-hands: - x = mag_catalog[use_for_calibration] - y = mag_inst[use_for_calibration] - yerr = mag_inst_err[use_for_calibration] - weights = 1.0 / yerr**2 - - if not any(use_for_calibration): - raise RuntimeError("No calibration stars") - - # Fit linear function with fixed slope, using sigma-clipping: - model = models.Linear1D(slope=1, fixed={'slope': True}) - fitter = fitting.FittingWithOutlierRemoval( - fitting.LinearLSQFitter(), - sigma_clip, - sigma=3.0) - best_fit, sigma_clipped = fitter(model, x, y, weights=weights) - - # Extract zero-point and estimate its error using a single weighted fit: - # I don't know why there is not an error-estimate attached directly to the Parameter? - zp = -1 * best_fit.intercept.value # Negative, because that is the way zeropoints are usually defined - - weights[sigma_clipped] = 0 # Trick to make following expression simpler - n_weights = len(weights.nonzero()[0]) - if n_weights > 1: - zp_error = np.sqrt(n_weights * nansum(weights * (y - best_fit(x))**2) / nansum(weights) / (n_weights - 1)) - else: - zp_error = np.NaN - logger.info('Leastsquare ZP = %.3f, ZP_error = %.3f', zp, zp_error) - - # Determine sigma clipping sigma according to Chauvenet method - # But don't allow less than sigma = sigmamin, setting to 1.5 for now. - # Should maybe be 2? 
-	sigmamin = 1.5
-	sig_chauv = sigma_from_Chauvenet(len(x))
-	sig_chauv = sig_chauv if sig_chauv >= sigmamin else sigmamin
-
-	# Extract zero point and error using bootstrap method
-	nboot = 1000
-	logger.info('Running bootstrap with sigma = %.2f and n = %d', sig_chauv, nboot)
-	pars = bootstrap_outlier(x, y, yerr,
-		n=nboot,
-		model=model,
-		fitter=fitting.LinearLSQFitter,
-		outlier=sigma_clip,
-		outlier_kwargs={'sigma': sig_chauv},
-		summary='median',
-		error='bootstrap',
-		return_vals=False)
-
-	zp_bs = pars['intercept'] * -1.0
-	zp_error_bs = pars['intercept_error']
-
-	logger.info('Bootstrapped ZP = %.3f, ZP_error = %.3f', zp_bs, zp_error_bs)
-
-	# Check that difference is not large
-	zp_diff = 0.4
-	if np.abs(zp_bs - zp) >= zp_diff:
-		logger.warning("Bootstrap and weighted LSQ ZPs differ by %.2f, \
+
+# --------------------------------------------------------------------------------------------------
+def photometry(fileid, output_folder=None, attempt_imagematch=True, keep_diff_fixed=False, cm_timeout=None):
+    """
+    Run photometry.
+
+    Parameters:
+        fileid (int): File ID to process.
+        output_folder (str, optional): Path to directory where output should be placed.
+        attempt_imagematch (bool, optional): If no subtracted image is available, but a
+            template image is, whether to attempt to run ImageMatch using standard settings.
+            Default=True.
+        keep_diff_fixed (bool, optional): Keep the PSF position fixed (do not recenter) when
+            calculating the flux for the difference image. Setting to True can help if the diff
+            image has non-source flux in the region around the SN. Default=False.
+        cm_timeout (float, optional): Timeout in seconds for the :class:`CoordinateMatch` algorithm.
+
+    .. codeauthor:: Rasmus Handberg
+    .. codeauthor:: Emir Karamehmetoglu
+    .. codeauthor:: Simon Holmbo
+    """
+
+    # Settings:
+    ref_target_dist_limit = 10 * u.arcsec  # Reference star must be further than this away to be included
+
+    logger = logging.getLogger(__name__)
+    tic = default_timer()
+
+    # Use local copy of archive if configured to do so:
+    config = load_config()
+
+    # Get datafile dict from API:
+    datafile = api.get_datafile(fileid)
+    logger.debug("Datafile: %s", datafile)
+    targetid = datafile['targetid']
+    target_name = datafile['target_name']
+    photfilter = datafile['photfilter']
+
+    archive_local = config.get('photometry', 'archive_local', fallback=None)
+    if archive_local is not None:
+        datafile['archive_path'] = archive_local
+    if not os.path.isdir(datafile['archive_path']):
+        raise FileNotFoundError("ARCHIVE is not available: " + datafile['archive_path'])
+
+    # Get the catalog containing the target and reference stars:
+    # TODO: Include proper-motion to the time of observation
+    catalog = api.get_catalog(targetid, output='table')
+    target = catalog['target'][0]
+    target_coord = coords.SkyCoord(ra=target['ra'], dec=target['decl'], unit='deg', frame='icrs')
+
+    # Folder to save output:
+    if output_folder is None:
+        output_folder_root = config.get('photometry', 'output', fallback='.')
+        output_folder = os.path.join(output_folder_root, target_name, f'{fileid:05d}')
+    logger.info("Placing output in '%s'", output_folder)
+    os.makedirs(output_folder, exist_ok=True)
+
+    # The paths to the science image:
+    filepath = os.path.join(datafile['archive_path'], datafile['path'])
+
+    # TODO: Download datafile using API to local drive:
+    # TODO: Is this a security concern?
+ # if archive_local: + # api.download_datafile(datafile, archive_local) + + # Translate photometric filter into table column: + ref_filter = {'up': 'u_mag', 'gp': 'g_mag', 'rp': 'r_mag', 'ip': 'i_mag', 'zp': 'z_mag', 'B': 'B_mag', 'V': 'V_mag', + 'J': 'J_mag', 'H': 'H_mag', 'K': 'K_mag', }.get(photfilter, None) + + if ref_filter is None: + logger.warning("Could not find filter '%s' in catalogs. Using default gp filter.", photfilter) + ref_filter = 'g_mag' + + # Load the image from the FITS file: + logger.info("Load image '%s'", filepath) + image = load_image(filepath, target_coord=target_coord) + + references = catalog['references'] + references.sort(ref_filter) + + # Check that there actually are reference stars in that filter: + if allnan(references[ref_filter]): + raise ValueError("No reference stars found in current photfilter.") + + # ============================================================================================== + # BARYCENTRIC CORRECTION OF TIME + # ============================================================================================== + + ltt_bary = image.obstime.light_travel_time(target_coord, ephemeris='jpl') + image.obstime = image.obstime.tdb + ltt_bary + + # ============================================================================================== + # BACKGROUND ESTIMATION + # ============================================================================================== + + fig, ax = plt.subplots(1, 2, figsize=(20, 18)) + plot_image(image.clean, ax=ax[0], scale='log', cbar='right', title='Image') + plot_image(image.mask, ax=ax[1], scale='linear', cbar='right', title='Mask') + fig.savefig(os.path.join(output_folder, 'original.png'), bbox_inches='tight') + plt.close(fig) + + # Estimate image background: + # Not using image.clean here, since we are redefining the mask anyway + background = Background2D(image.clean, (128, 128), filter_size=(5, 5), sigma_clip=SigmaClip(sigma=3.0), + bkg_estimator=SExtractorBackground(), exclude_percentile=50.0) + + # Create background-subtracted image: + image.subclean = image.clean - background.background + + # Plot background estimation: + fig, ax = plt.subplots(1, 3, figsize=(20, 6)) + plot_image(image.clean, ax=ax[0], scale='log', title='Original') + plot_image(background.background, ax=ax[1], scale='log', title='Background') + plot_image(image.subclean, ax=ax[2], scale='log', title='Background subtracted') + fig.savefig(os.path.join(output_folder, 'background.png'), bbox_inches='tight') + plt.close(fig) + + # TODO: Is this correct?! 
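+    # (For reference: calc_total_error combines the background RMS with a
+    # Poisson term, sigma_tot = sqrt(bkg_rms**2 + data / effective_gain), so
+    # effective_gain=1.0 below assumes the image is already in electrons;
+    # for images in ADU the instrument gain should be used instead.)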
+    image.error = calc_total_error(image.clean, background.background_rms, 1.0)
+
+    # Use SEP for source extraction:
+    sep_background = sep.Background(image.image, mask=image.mask)
+    objects = sep.extract(image.image - sep_background, thresh=5., err=sep_background.globalrms, mask=image.mask,
+                          deblend_cont=0.1, minarea=9, clean_param=2.0)
+
+    # Cleanup large arrays which are no longer needed:
+    del background, fig, ax, sep_background, ltt_bary
+    gc.collect()
+
+    # ==============================================================================================
+    # DETECTION OF STARS AND MATCHING WITH CATALOG
+    # ==============================================================================================
+
+    # Account for proper motion:
+    replace(references['pm_ra'], np.NaN, 0)
+    replace(references['pm_dec'], np.NaN, 0)
+    refs_coord = coords.SkyCoord(ra=references['ra'], dec=references['decl'], pm_ra_cosdec=references['pm_ra'],
+                                 pm_dec=references['pm_dec'], unit='deg', frame='icrs',
+                                 obstime=Time(2015.5, format='decimalyear'))
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", ErfaWarning)
+        refs_coord = refs_coord.apply_space_motion(new_obstime=image.obstime)
+
+    # TODO: These need to be based on the instrument!
+    radius = 10
+    fwhm_guess = 6.0
+    fwhm_min = 3.5
+    fwhm_max = 18.0
+
+    # Clean extracted stars:
+    masked_sep_xy, sep_mask, masked_sep_rsqs = refclean.force_reject_g2d(objects['x'], objects['y'], image,
+                                                                         get_fwhm=False, radius=radius,
+                                                                         fwhm_guess=fwhm_guess, rsq_min=0.3,
+                                                                         fwhm_max=fwhm_max, fwhm_min=fwhm_min)
+
+    logger.info("Finding new WCS solution...")
+    head_wcs = str(WCS2.from_astropy_wcs(image.wcs))
+    logger.debug('Head WCS: %s', head_wcs)
+
+    # Solve for new WCS:
+    cm = CoordinateMatch(xy=list(masked_sep_xy[sep_mask]), rd=list(zip(refs_coord.ra.deg, refs_coord.dec.deg)),
+                         xy_order=np.argsort(
+                             np.power(masked_sep_xy[sep_mask] - np.array(image.shape[::-1]) / 2, 2).sum(axis=1)),
+                         rd_order=np.argsort(target_coord.separation(refs_coord)), xy_nmax=100, rd_nmax=100,
+                         maximum_angle_distance=0.002)
+
+    # Set the timeout parameter to infinity unless specified.
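+    # (If the coordinate match below raises TimeoutError or StopIteration,
+    # no new solution is adopted and the WCS from the FITS header is kept.)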
+    if cm_timeout is None:
+        cm_timeout = float('inf')
+    try:
+        i_xy, i_rd = map(np.array, zip(*cm(5, 1.5, timeout=cm_timeout)))
+    except TimeoutError:
+        logger.warning('TimeoutError: No new WCS solution found')
+    except StopIteration:
+        logger.warning('StopIteration: No new WCS solution found')
+    else:
+        logger.info('Found new WCS')
+        image.wcs = fit_wcs_from_points(np.array(list(zip(*cm.xy[i_xy]))),
+                                        coords.SkyCoord(*map(list, zip(*cm.rd[i_rd])), unit='deg'))
+        del i_xy, i_rd
+
+    used_wcs = str(WCS2.from_astropy_wcs(image.wcs))
+    logger.debug('Used WCS: %s', used_wcs)
+
+    # Calculate pixel-coordinates of references:
+    xy = image.wcs.all_world2pix(list(zip(refs_coord.ra.deg, refs_coord.dec.deg)), 0)
+    references['pixel_column'], references['pixel_row'] = x, y = list(map(np.array, zip(*xy)))
+
+    # Clean out the references:
+    hsize = 10
+    clean_references = references[(target_coord.separation(refs_coord) > ref_target_dist_limit) & (x > hsize) & (
+        x < (image.shape[1] - 1 - hsize)) & (y > hsize) & (y < (image.shape[0] - 1 - hsize))]
+
+    if not clean_references:
+        raise RuntimeError('No clean references in field')
+
+    # Calculate the target's position in the image:
+    target_pixel_pos = image.wcs.all_world2pix([(target['ra'], target['decl'])], 0)[0]
+
+    # Clean reference star locations:
+    masked_fwhms, masked_ref_xys, rsq_mask, masked_rsqs = refclean.force_reject_g2d(clean_references['pixel_column'],
+                                                                                    clean_references['pixel_row'],
+                                                                                    image, get_fwhm=True, radius=radius,
+                                                                                    fwhm_guess=fwhm_guess,
+                                                                                    fwhm_max=fwhm_max,
+                                                                                    fwhm_min=fwhm_min, rsq_min=0.15)
+
+    # Use R^2 to more robustly determine initial FWHM guess.
+    # This cleaning is good when we have FEW references.
+    fwhm, fwhm_clean_references = refclean.clean_with_rsq_and_get_fwhm(masked_fwhms, masked_rsqs, clean_references,
+                                                                       min_fwhm_references=2, min_references=6,
+                                                                       rsq_min=0.15)
+    logger.info('Initial FWHM guess is %f pixels', fwhm)
+
+    # Create plot of target and reference star positions from 2D Gaussian fits.
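+    # (Red circles: references surviving the R^2/FWHM cleaning; green squares:
+    # SEP detections; red plus: the expected target position.)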
+    fig, ax = plt.subplots(1, 1, figsize=(20, 18))
+    plot_image(image.subclean, ax=ax, scale='log', cbar='right', title=target_name)
+    ax.scatter(fwhm_clean_references['pixel_column'], fwhm_clean_references['pixel_row'], c='r', marker='o', alpha=0.3)
+    ax.scatter(masked_sep_xy[:, 0], masked_sep_xy[:, 1], marker='s', alpha=1.0, edgecolors='green', facecolors='none')
+    ax.scatter(target_pixel_pos[0], target_pixel_pos[1], marker='+', s=20, c='r')
+    fig.savefig(os.path.join(output_folder, 'positions_g2d.png'), bbox_inches='tight')
+    plt.close(fig)
+
+    # Final clean of WCS-corrected references:
+    logger.info("Number of references before final cleaning: %d", len(clean_references))
+    logger.debug('Masked R^2 values: %s', masked_rsqs[rsq_mask])
+    references = refclean.get_clean_references(clean_references, masked_rsqs, rsq_ideal=0.8)
+    logger.info("Number of references after final cleaning: %d", len(references))
+
+    # Create plot of target and reference star positions:
+    fig, ax = plt.subplots(1, 1, figsize=(20, 18))
+    plot_image(image.subclean, ax=ax, scale='log', cbar='right', title=target_name)
+    ax.scatter(references['pixel_column'], references['pixel_row'], c='r', marker='o', alpha=0.6)
+    ax.scatter(masked_sep_xy[:, 0], masked_sep_xy[:, 1], marker='s', alpha=0.6, edgecolors='green', facecolors='none')
+    ax.scatter(target_pixel_pos[0], target_pixel_pos[1], marker='+', s=20, c='r')
+    fig.savefig(os.path.join(output_folder, 'positions.png'), bbox_inches='tight')
+    plt.close(fig)
+
+    # Cleanup large arrays which are no longer needed:
+    del fig, ax, cm
+    gc.collect()
+
+    # ==============================================================================================
+    # CREATE EFFECTIVE PSF MODEL
+    # ==============================================================================================
+
+    # Make cutouts of stars using extract_stars:
+    # Scales with FWHM
+    size = int(np.round(29 * fwhm / 6))
+    size += 0 if size % 2 else 1  # Make sure it's an odd number
+    size = max(size, 15)  # Never go below 15 pixels
+
+    # Extract stars sub-images:
+    xy = [tuple(masked_ref_xys[clean_references['starid'] == ref['starid']].data[0]) for ref in references]
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore', AstropyUserWarning)
+        stars = extract_stars(NDData(data=image.subclean.data, mask=image.mask), Table(np.array(xy), names=('x', 'y')),
+                              size=size + 6  # +6 for edge buffer
+                              )
+
+    logger.info("Number of stars input to ePSF builder: %d", len(stars))
+
+    # Plot the stars being used for ePSF:
+    imgnr = 0
+    nrows, ncols = 5, 5
+    for k in range(int(np.ceil(len(stars) / (nrows * ncols)))):
+        fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), squeeze=True)
+        ax = ax.ravel()
+        for i in range(nrows * ncols):
+            if imgnr > len(stars) - 1:
+                ax[i].axis('off')
+            else:
+                plot_image(stars[imgnr], ax=ax[i], scale='log', cmap='viridis')  # FIXME (no x-ticks)
+            imgnr += 1
+
+        fig.savefig(os.path.join(output_folder, f'epsf_stars{k + 1:02d}.png'), bbox_inches='tight')
+        plt.close(fig)
+
+    # Build the ePSF:
+    epsf, stars = FlowsEPSFBuilder(oversampling=1, shape=1 * size,
+                                   fitter=EPSFFitter(fit_boxsize=max(int(np.round(1.5 * fwhm)), 5)),
+                                   recentering_boxsize=max(int(np.round(2 * fwhm)), 5), norm_radius=max(fwhm, 5),
+                                   maxiters=100, progress_bar=logger.isEnabledFor(logging.INFO))(stars)
+    logger.info('Built PSF model (%(n_iter)d/%(max_iters)d) in %(time).1f seconds', epsf.fit_info)
+
+    # Store which stars were used in ePSF in the table:
+    references['used_for_epsf'] = False
+    references['used_for_epsf'][[star.id_label - 1 for star in stars.all_good_stars]] = True
+    logger.info("Number of stars used for ePSF: %d", np.sum(references['used_for_epsf']))
+
+    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 15))
+    plot_image(epsf.data, ax=ax1, cmap='viridis')
+
+    fwhms = []
+    bad_epsf_detected = False
+    for a, ax in ((0, ax3), (1, ax2)):
+        # Collapse the PSF model along this axis:
+        profile = epsf.data.sum(axis=a)
+        itop = profile.argmax()
+        poffset = profile[itop] / 2
+
+        # Run a spline through the points, but subtract half of the peak value, and find the roots:
+        # We have to use a cubic spline, since roots() is not supported for other splines
+        # for some reason
+        profile_intp = UnivariateSpline(np.arange(0, len(profile)), profile - poffset, k=3, s=0, ext=3)
+        lr = profile_intp.roots()
+
+        # Plot the profile and spline:
+        x_fine = np.linspace(-0.5, len(profile) - 0.5, 500)
+        ax.plot(profile, 'k.-')
+        ax.plot(x_fine, profile_intp(x_fine) + poffset, 'g-')
+        ax.axvline(itop)
+        ax.set_xlim(-0.5, len(profile) - 0.5)
+
+        # Do some sanity checks on the ePSF:
+        # It should pass 50% exactly twice and have the maximum inside that region.
+        # I.e. it should be a single Gaussian-like peak.
+        if len(lr) != 2 or itop < lr[0] or itop > lr[1]:
+            logger.error("Bad PSF along axis %d", a)
+            bad_epsf_detected = True
+        else:
+            axis_fwhm = lr[1] - lr[0]
+            fwhms.append(axis_fwhm)
+            ax.axvspan(lr[0], lr[1], facecolor='g', alpha=0.2)
+
+    # Save the ePSF figure:
+    ax4.axis('off')
+    fig.savefig(os.path.join(output_folder, 'epsf.png'), bbox_inches='tight')
+    plt.close(fig)
+
+    # There was a problem with the ePSF:
+    if bad_epsf_detected:
+        raise RuntimeError("Bad ePSF detected.")
+
+    # Let's make the final FWHM the largest one we found:
+    fwhm = np.max(fwhms)
+    logger.info("Final FWHM based on ePSF: %f", fwhm)
+
+    # Cleanup large arrays which are no longer needed:
+    del fig, ax, stars, fwhms, profile_intp
+    gc.collect()
+
+    # ==============================================================================================
+    # COORDINATES TO DO PHOTOMETRY AT
+    # ==============================================================================================
+
+    coordinates = np.array([[ref['pixel_column'], ref['pixel_row']] for ref in references])
+
+    # Add the main target position as the first entry for doing photometry directly in the
+    # science image:
+    coordinates = np.concatenate(([target_pixel_pos], coordinates), axis=0)
+
+    # ==============================================================================================
+    # APERTURE PHOTOMETRY
+    # ==============================================================================================
+
+    # Define apertures for aperture photometry:
+    apertures = CircularAperture(coordinates, r=fwhm)
+    annuli = CircularAnnulus(coordinates, r_in=1.5 * fwhm, r_out=2.5 * fwhm)
+
+    apphot_tbl = aperture_photometry(image.subclean, [apertures, annuli], mask=image.mask, error=image.error)
+
+    logger.info('Aperture Photometry Success')
+    logger.debug("Aperture Photometry Table:\n%s", apphot_tbl)
+
+    # ==============================================================================================
+    # PSF PHOTOMETRY
+    # ==============================================================================================
+
+    # Create photometry object:
+    photometry_obj = BasicPSFPhotometry(group_maker=DAOGroup(fwhm), bkg_estimator=MedianBackground(), psf_model=epsf,
+                                        fitter=fitting.LevMarLSQFitter(), fitshape=size, aperture_radius=fwhm)
+
+    psfphot_tbl = photometry_obj(image=image.subclean, init_guesses=Table(coordinates, names=['x_0', 'y_0']))
+
+    logger.info('PSF Photometry Success')
+    logger.debug("PSF Photometry Table:\n%s", psfphot_tbl)
+
+    # ==============================================================================================
+    # TEMPLATE SUBTRACTION AND TARGET PHOTOMETRY
+    # ==============================================================================================
+
+    # Find the pixel-scale of the science image:
+    pixel_area = proj_plane_pixel_area(image.wcs.celestial)
+    pixel_scale = np.sqrt(pixel_area) * 3600  # arcsec/pixel
+    # print(image.wcs.celestial.cunit) % Doesn't work?
+    logger.info("Science image pixel scale: %f", pixel_scale)
+
+    diffimage = None
+    if datafile.get('diffimg') is not None:
+        diffimg_path = os.path.join(datafile['archive_path'], datafile['diffimg']['path'])
+        diffimg = load_image(diffimg_path)
+        diffimage = diffimg.image
+
+    elif attempt_imagematch and datafile.get('template') is not None:
+        # Run the template subtraction, and get back
+        # the science image where the template has been subtracted:
+        diffimage = run_imagematch(datafile, target, star_coord=coordinates, fwhm=fwhm, pixel_scale=pixel_scale)
+
+    # We have a diff image, so let's do photometry of the target using this:
+    if diffimage is not None:
+        # Include mask from original image:
+        diffimage = np.ma.masked_array(diffimage, image.mask)
+
+        # Create apertures around the target:
+        apertures = CircularAperture(target_pixel_pos, r=fwhm)
+        annuli = CircularAnnulus(target_pixel_pos, r_in=1.5 * fwhm, r_out=2.5 * fwhm)
+
+        # Create two plots of the difference image:
+        fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(20, 20))
+        plot_image(diffimage, ax=ax, cbar='right', title=target_name)
+        ax.plot(target_pixel_pos[0], target_pixel_pos[1], marker='+', markersize=20, color='r')
+        fig.savefig(os.path.join(output_folder, 'diffimg.png'), bbox_inches='tight')
+        apertures.plot(axes=ax, color='r', lw=2)
+        annuli.plot(axes=ax, color='r', lw=2)
+        ax.set_xlim(target_pixel_pos[0] - 50, target_pixel_pos[0] + 50)
+        ax.set_ylim(target_pixel_pos[1] - 50, target_pixel_pos[1] + 50)
+        fig.savefig(os.path.join(output_folder, 'diffimg_zoom.png'), bbox_inches='tight')
+        plt.close(fig)
+
+        # Run aperture photometry on subtracted image:
+        target_apphot_tbl = aperture_photometry(diffimage, [apertures, annuli], mask=image.mask, error=image.error)
+
+        # Make target-only photometry object if keep_diff_fixed = True
+        if keep_diff_fixed:
+            epsf.fixed.update({'x_0': True, 'y_0': True})
+
+        # TODO: Try iteratively subtracted photometry
+        # Create photometry object:
+        photometry_obj = BasicPSFPhotometry(group_maker=DAOGroup(0.0001), bkg_estimator=MedianBackground(),
+                                            psf_model=epsf, fitter=fitting.LevMarLSQFitter(), fitshape=size,
+                                            aperture_radius=fwhm)
+
+        # Run PSF photometry on template subtracted image:
+        target_psfphot_tbl = photometry_obj(diffimage, init_guesses=Table(target_pixel_pos, names=['x_0', 'y_0']))
+
+        # Need to adjust table columns if x_0 and y_0 were fixed:
+        if keep_diff_fixed:
+            target_psfphot_tbl['x_0_unc'] = 0.0
+            target_psfphot_tbl['y_0_unc'] = 0.0
+
+        # Combine the output tables from the target and the reference stars into one:
+        apphot_tbl = vstack([target_apphot_tbl, apphot_tbl], join_type='exact')
+        psfphot_tbl = vstack([target_psfphot_tbl, psfphot_tbl], join_type='exact')
+
+    # Build results table:
+    tab = references.copy()
+
+    row = {'starid': 0, 'ra': target['ra'], 'decl': target['decl'], 'pixel_column': target_pixel_pos[0],
+           'pixel_row':
target_pixel_pos[1], 'used_for_epsf': False} + row.update([(k, np.NaN) for k in set(tab.keys()) - set(row) - {'gaia_variability'}]) + tab.insert_row(0, row) + + if diffimage is not None: + row['starid'] = -1 + tab.insert_row(0, row) + + indx_main_target = tab['starid'] <= 0 + + # Subtract background estimated from annuli: + flux_aperture = apphot_tbl['aperture_sum_0'] - (apphot_tbl['aperture_sum_1'] / annuli.area) * apertures.area + flux_aperture_error = np.sqrt( + apphot_tbl['aperture_sum_err_0'] ** 2 + (apphot_tbl['aperture_sum_err_1'] / annuli.area * apertures.area) ** 2) + + # Add table columns with results: + tab['flux_aperture'] = flux_aperture / image.exptime + tab['flux_aperture_error'] = flux_aperture_error / image.exptime + tab['flux_psf'] = psfphot_tbl['flux_fit'] / image.exptime + tab['flux_psf_error'] = psfphot_tbl['flux_unc'] / image.exptime + tab['pixel_column_psf_fit'] = psfphot_tbl['x_fit'] + tab['pixel_row_psf_fit'] = psfphot_tbl['y_fit'] + tab['pixel_column_psf_fit_error'] = psfphot_tbl['x_0_unc'] + tab['pixel_row_psf_fit_error'] = psfphot_tbl['y_0_unc'] + + # Check that we got valid photometry: + if np.any(~np.isfinite(tab[indx_main_target]['flux_psf'])) or np.any( + ~np.isfinite(tab[indx_main_target]['flux_psf_error'])): + raise RuntimeError("Target magnitude is undefined.") + + # ============================================================================================== + # CALIBRATE + # ============================================================================================== + + # Convert PSF fluxes to magnitudes: + mag_inst = -2.5 * np.log10(tab['flux_psf']) + mag_inst_err = (2.5 / np.log(10)) * (tab['flux_psf_error'] / tab['flux_psf']) + + # Corresponding magnitudes in catalog: + mag_catalog = tab[ref_filter] + + # Mask out things that should not be used in calibration: + use_for_calibration = np.ones_like(mag_catalog, dtype='bool') + use_for_calibration[indx_main_target] = False # Do not use target for calibration + use_for_calibration[~np.isfinite(mag_inst) | ~np.isfinite(mag_catalog)] = False + + # Just creating some short-hands: + x = mag_catalog[use_for_calibration] + y = mag_inst[use_for_calibration] + yerr = mag_inst_err[use_for_calibration] + weights = 1.0 / yerr ** 2 + + if not any(use_for_calibration): + raise RuntimeError("No calibration stars") + + # Fit linear function with fixed slope, using sigma-clipping: + model = models.Linear1D(slope=1, fixed={'slope': True}) + fitter = fitting.FittingWithOutlierRemoval(fitting.LinearLSQFitter(), sigma_clip, sigma=3.0) + best_fit, sigma_clipped = fitter(model, x, y, weights=weights) + + # Extract zero-point and estimate its error using a single weighted fit: + # I don't know why there is not an error-estimate attached directly to the Parameter? + zp = -1 * best_fit.intercept.value # Negative, because that is the way zeropoints are usually defined + + weights[sigma_clipped] = 0 # Trick to make following expression simpler + n_weights = len(weights.nonzero()[0]) + if n_weights > 1: + zp_error = np.sqrt(n_weights * nansum(weights * (y - best_fit(x)) ** 2) / nansum(weights) / (n_weights - 1)) + else: + zp_error = np.NaN + logger.info('Leastsquare ZP = %.3f, ZP_error = %.3f', zp, zp_error) + + # Determine sigma clipping sigma according to Chauvenet method + # But don't allow less than sigma = sigmamin, setting to 1.5 for now. + # Should maybe be 2? 
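+    # (Chauvenet's criterion rejects a point whose expected number of
+    # occurrences among the N calibration stars is below one half, i.e. it
+    # solves erfc(sigma / sqrt(2)) = 1 / (2 * N) for sigma; this assumes that
+    # is what sigma_from_Chauvenet implements.)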
+ sigmamin = 1.5 + sig_chauv = sigma_from_Chauvenet(len(x)) + sig_chauv = sig_chauv if sig_chauv >= sigmamin else sigmamin + + # Extract zero point and error using bootstrap method + nboot = 1000 + logger.info('Running bootstrap with sigma = %.2f and n = %d', sig_chauv, nboot) + pars = bootstrap_outlier(x, y, yerr, n=nboot, model=model, fitter=fitting.LinearLSQFitter, outlier=sigma_clip, + outlier_kwargs={'sigma': sig_chauv}, summary='median', error='bootstrap', + return_vals=False) + + zp_bs = pars['intercept'] * -1.0 + zp_error_bs = pars['intercept_error'] + + logger.info('Bootstrapped ZP = %.3f, ZP_error = %.3f', zp_bs, zp_error_bs) + + # Check that difference is not large + zp_diff = 0.4 + if np.abs(zp_bs - zp) >= zp_diff: + logger.warning("Bootstrap and weighted LSQ ZPs differ by %.2f, \ which is more than the allowed %.2f mag.", np.abs(zp_bs - zp), zp_diff) - # Add calibrated magnitudes to the photometry table: - tab['mag'] = mag_inst + zp_bs - tab['mag_error'] = np.sqrt(mag_inst_err**2 + zp_error_bs**2) - - fig, ax = plt.subplots(1, 1) - ax.errorbar(x, y, yerr=yerr, fmt='k.') - ax.scatter(x[sigma_clipped], y[sigma_clipped], marker='x', c='r') - ax.plot(x, best_fit(x), color='g', linewidth=3) - ax.set_xlabel('Catalog magnitude') - ax.set_ylabel('Instrumental magnitude') - fig.savefig(os.path.join(output_folder, 'calibration.png'), bbox_inches='tight') - plt.close(fig) - - # Check that we got valid photometry: - if not np.isfinite(tab[0]['mag']) or not np.isfinite(tab[0]['mag_error']): - raise RuntimeError("Target magnitude is undefined.") - - #============================================================================================== - # SAVE PHOTOMETRY - #============================================================================================== - - # Descriptions of columns: - tab['used_for_epsf'].description = 'Was object used for building ePSF?' 
- tab['mag'].description = 'Measured magnitude' - tab['mag'].unit = u.mag - tab['mag_error'].description = 'Error on measured magnitude' - tab['mag_error'].unit = u.mag - tab['flux_aperture'].description = 'Measured flux using aperture photometry' - tab['flux_aperture'].unit = u.count / u.second - tab['flux_aperture_error'].description = 'Error on measured flux using aperture photometry' - tab['flux_aperture_error'].unit = u.count / u.second - tab['flux_psf'].description = 'Measured flux using PSF photometry' - tab['flux_psf'].unit = u.count / u.second - tab['flux_psf_error'].description = 'Error on measured flux using PSF photometry' - tab['flux_psf_error'].unit = u.count / u.second - tab['pixel_column'].description = 'Location on image pixel columns' - tab['pixel_column'].unit = u.pixel - tab['pixel_row'].description = 'Location on image pixel rows' - tab['pixel_row'].unit = u.pixel - tab['pixel_column_psf_fit'].description = 'Measured location on image pixel columns from PSF photometry' - tab['pixel_column_psf_fit'].unit = u.pixel - tab['pixel_column_psf_fit_error'].description = 'Error on measured location on image pixel columns from PSF photometry' - tab['pixel_column_psf_fit_error'].unit = u.pixel - tab['pixel_row_psf_fit'].description = 'Measured location on image pixel rows from PSF photometry' - tab['pixel_row_psf_fit'].unit = u.pixel - tab['pixel_row_psf_fit_error'].description = 'Error on measured location on image pixel rows from PSF photometry' - tab['pixel_row_psf_fit_error'].unit = u.pixel - - # Meta-data: - tab.meta['fileid'] = fileid - tab.meta['target_name'] = target_name - tab.meta['version'] = __version__ - tab.meta['template'] = None if datafile.get('template') is None else datafile['template']['fileid'] - tab.meta['diffimg'] = None if datafile.get('diffimg') is None else datafile['diffimg']['fileid'] - tab.meta['photfilter'] = photfilter - tab.meta['fwhm'] = fwhm * u.pixel - tab.meta['pixel_scale'] = pixel_scale * u.arcsec / u.pixel - tab.meta['seeing'] = (fwhm * pixel_scale) * u.arcsec - tab.meta['obstime-bmjd'] = float(image.obstime.mjd) - tab.meta['zp'] = zp_bs - tab.meta['zp_error'] = zp_error_bs - tab.meta['zp_diff'] = np.abs(zp_bs - zp) - tab.meta['zp_error_weights'] = zp_error - tab.meta['head_wcs'] = head_wcs # TODO: Are these really useful? - tab.meta['used_wcs'] = used_wcs # TODO: Are these really useful? 
- - # Filepath where to save photometry: - photometry_output = os.path.join(output_folder, 'photometry.ecsv') - - # Write the final table to file: - tab.write(photometry_output, format='ascii.ecsv', delimiter=',', overwrite=True) - - toc = default_timer() - - logger.info("------------------------------------------------------") - logger.info("Success!") - logger.info("Main target: %f +/- %f", tab[0]['mag'], tab[0]['mag_error']) - logger.info("Photometry took: %.1f seconds", toc - tic) - - return photometry_output + # Add calibrated magnitudes to the photometry table: + tab['mag'] = mag_inst + zp_bs + tab['mag_error'] = np.sqrt(mag_inst_err ** 2 + zp_error_bs ** 2) + + fig, ax = plt.subplots(1, 1) + ax.errorbar(x, y, yerr=yerr, fmt='k.') + ax.scatter(x[sigma_clipped], y[sigma_clipped], marker='x', c='r') + ax.plot(x, best_fit(x), color='g', linewidth=3) + ax.set_xlabel('Catalog magnitude') + ax.set_ylabel('Instrumental magnitude') + fig.savefig(os.path.join(output_folder, 'calibration.png'), bbox_inches='tight') + plt.close(fig) + + # Check that we got valid photometry: + if not np.isfinite(tab[0]['mag']) or not np.isfinite(tab[0]['mag_error']): + raise RuntimeError("Target magnitude is undefined.") + + # ============================================================================================== + # SAVE PHOTOMETRY + # ============================================================================================== + + # Descriptions of columns: + tab['used_for_epsf'].description = 'Was object used for building ePSF?' + tab['mag'].description = 'Measured magnitude' + tab['mag'].unit = u.mag + tab['mag_error'].description = 'Error on measured magnitude' + tab['mag_error'].unit = u.mag + tab['flux_aperture'].description = 'Measured flux using aperture photometry' + tab['flux_aperture'].unit = u.count / u.second + tab['flux_aperture_error'].description = 'Error on measured flux using aperture photometry' + tab['flux_aperture_error'].unit = u.count / u.second + tab['flux_psf'].description = 'Measured flux using PSF photometry' + tab['flux_psf'].unit = u.count / u.second + tab['flux_psf_error'].description = 'Error on measured flux using PSF photometry' + tab['flux_psf_error'].unit = u.count / u.second + tab['pixel_column'].description = 'Location on image pixel columns' + tab['pixel_column'].unit = u.pixel + tab['pixel_row'].description = 'Location on image pixel rows' + tab['pixel_row'].unit = u.pixel + tab['pixel_column_psf_fit'].description = 'Measured location on image pixel columns from PSF photometry' + tab['pixel_column_psf_fit'].unit = u.pixel + tab[ + 'pixel_column_psf_fit_error'].description = 'Error on measured location on image pixel columns from PSF photometry' + tab['pixel_column_psf_fit_error'].unit = u.pixel + tab['pixel_row_psf_fit'].description = 'Measured location on image pixel rows from PSF photometry' + tab['pixel_row_psf_fit'].unit = u.pixel + tab['pixel_row_psf_fit_error'].description = 'Error on measured location on image pixel rows from PSF photometry' + tab['pixel_row_psf_fit_error'].unit = u.pixel + + # Meta-data: + tab.meta['fileid'] = fileid + tab.meta['target_name'] = target_name + tab.meta['version'] = __version__ + tab.meta['template'] = None if datafile.get('template') is None else datafile['template']['fileid'] + tab.meta['diffimg'] = None if datafile.get('diffimg') is None else datafile['diffimg']['fileid'] + tab.meta['photfilter'] = photfilter + tab.meta['fwhm'] = fwhm * u.pixel + tab.meta['pixel_scale'] = pixel_scale * u.arcsec / u.pixel + 
tab.meta['seeing'] = (fwhm * pixel_scale) * u.arcsec
+    tab.meta['obstime-bmjd'] = float(image.obstime.mjd)
+    tab.meta['zp'] = zp_bs
+    tab.meta['zp_error'] = zp_error_bs
+    tab.meta['zp_diff'] = np.abs(zp_bs - zp)
+    tab.meta['zp_error_weights'] = zp_error
+    tab.meta['head_wcs'] = head_wcs  # TODO: Are these really useful?
+    tab.meta['used_wcs'] = used_wcs  # TODO: Are these really useful?
+
+    # Filepath where to save photometry:
+    photometry_output = os.path.join(output_folder, 'photometry.ecsv')
+
+    # Write the final table to file:
+    tab.write(photometry_output, format='ascii.ecsv', delimiter=',', overwrite=True)
+
+    toc = default_timer()
+
+    logger.info("------------------------------------------------------")
+    logger.info("Success!")
+    logger.info("Main target: %f +/- %f", tab[0]['mag'], tab[0]['mag_error'])
+    logger.info("Photometry took: %.1f seconds", toc - tic)
+
+    return photometry_output
diff --git a/flows/plots.py b/flows/plots.py
index b051441..c5c03b3 100644
--- a/flows/plots.py
+++ b/flows/plots.py
@@ -26,233 +26,230 @@
 matplotlib.rcParams['text.usetex'] = False
 matplotlib.rcParams['mathtext.fontset'] = 'dejavuserif'
-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 def plots_interactive(backend=('Qt5Agg', 'MacOSX', 'Qt4Agg', 'Qt5Cairo', 'TkAgg')):
-	"""
-	Change plotting to using an interactive backend.
+    """
+    Change plotting to use an interactive backend.
+
+    Parameters:
+        backend (str or list): Backend to change to. If not provided, will try different
+            interactive backends and use the first one that works.
 
-	Parameters:
-		backend (str or list): Backend to change to. If not provided, will try different
-			interactive backends and use the first one that works.
+    .. codeauthor:: Rasmus Handberg
+    """
 
-	.. codeauthor:: Rasmus Handberg
-	"""
+    logger = logging.getLogger(__name__)
+    logger.debug("Valid interactive backends: %s", matplotlib.rcsetup.interactive_bk)
 
-	logger = logging.getLogger(__name__)
-	logger.debug("Valid interactive backends: %s", matplotlib.rcsetup.interactive_bk)
+    if isinstance(backend, str):
+        backend = [backend]
 
-	if isinstance(backend, str):
-		backend = [backend]
+    for bckend in backend:
+        if bckend not in matplotlib.rcsetup.interactive_bk:
+            logger.warning("Interactive backend '%s' is not found", bckend)
+            continue
 
-	for bckend in backend:
-		if bckend not in matplotlib.rcsetup.interactive_bk:
-			logger.warning("Interactive backend '%s' is not found", bckend)
-			continue
+        # Try to change the backend, and catch errors
+        # if it didn't work:
+        try:
+            plt.switch_backend(bckend)
+        except (ModuleNotFoundError, ImportError):
+            pass
+        else:
+            break
 
-		# Try to change the backend, and catch errors
-		# it it didn't work:
-		try:
-			plt.switch_backend(bckend)
-		except (ModuleNotFoundError, ImportError):
-			pass
-		else:
-			break
-#--------------------------------------------------------------------------------------------------
+# --------------------------------------------------------------------------------------------------
 def plots_noninteractive():
-	"""
-	Change plotting to using a non-interactive backend, which can e.g. be used on a cluster.
-	Will set backend to 'Agg'.
-
-	..
codeauthor:: Rasmus Handberg - """ - plt.switch_backend('Agg') - -#-------------------------------------------------------------------------------------------------- -def plot_image(image, ax=None, scale='log', cmap=None, origin='lower', xlabel=None, - ylabel=None, cbar=None, clabel='Flux ($e^{-}s^{-1}$)', cbar_ticks=None, - cbar_ticklabels=None, cbar_pad=None, cbar_size='4%', title=None, - percentile=95.0, vmin=None, vmax=None, offset_axes=None, color_bad='k', **kwargs): - """ - Utility function to plot a 2D image. - - Parameters: - image (2d array): Image data. - ax (matplotlib.pyplot.axes, optional): Axes in which to plot. - Default (None) is to use current active axes. - scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional): - Normalization used to stretch the colormap. - Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'`` - and ``'squared'``. - Can also be a :py:class:`astropy.visualization.ImageNormalize` object. - Default is ``'log'``. - origin (str, optional): The origin of the coordinate system. - xlabel (str, optional): Label for the x-axis. - ylabel (str, optional): Label for the y-axis. - cbar (string, optional): Location of color bar. - Choises are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``. - Default is not to create colorbar. - clabel (str, optional): Label for the color bar. - cbar_size (float, optional): Fractional size of colorbar compared to axes. Default='4%'. - cbar_pad (float, optional): Padding between axes and colorbar. - title (str or None, optional): Title for the plot. - percentile (float, optional): The fraction of pixels to keep in color-trim. - If single float given, the same fraction of pixels is eliminated from both ends. - If tuple of two floats is given, the two are used as the percentiles. - Default=95. - cmap (matplotlib colormap, optional): Colormap to use. Default is the ``Blues`` colormap. - vmin (float, optional): Lower limit to use for colormap. - vmax (float, optional): Upper limit to use for colormap. - color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black. - kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`. - - Returns: - :py:class:`matplotlib.image.AxesImage`: Image from returned - by :py:func:`matplotlib.pyplot.imshow`. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - # Backward compatible settings: - make_cbar = kwargs.pop('make_cbar', None) - if make_cbar: - raise FutureWarning("'make_cbar' is deprecated. 
Use 'cbar' instead.") - if not cbar: - cbar = make_cbar - - # Special treatment for boolean arrays: - if isinstance(image, np.ndarray) and image.dtype == 'bool': - if vmin is None: vmin = 0 - if vmax is None: vmax = 1 - if cbar_ticks is None: cbar_ticks = [0, 1] - if cbar_ticklabels is None: cbar_ticklabels = ['False', 'True'] - - # Calculate limits of color scaling: - interval = None - if vmin is None or vmax is None: - if allnan(image): - logger.warning("Image is all NaN") - vmin = 0 - vmax = 1 - if cbar_ticks is None: - cbar_ticks = [] - if cbar_ticklabels is None: - cbar_ticklabels = [] - elif isinstance(percentile, (list, tuple, np.ndarray)): - interval = viz.AsymmetricPercentileInterval(percentile[0], percentile[1]) - else: - interval = viz.PercentileInterval(percentile) - - # Create ImageNormalize object with extracted limits: - if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh', 'squared'): - if scale == 'log': - stretch = viz.LogStretch() - elif scale == 'linear': - stretch = viz.LinearStretch() - elif scale == 'sqrt': - stretch = viz.SqrtStretch() - elif scale == 'asinh': - stretch = viz.AsinhStretch() - elif scale == 'histeq': - stretch = viz.HistEqStretch(image[np.isfinite(image)]) - elif scale == 'sinh': - stretch = viz.SinhStretch() - elif scale == 'squared': - stretch = viz.SquaredStretch() - - # Create ImageNormalize object. Very important to use clip=False here, otherwise - # NaN points will not be plotted correctly. - norm = viz.ImageNormalize( - data=image, - interval=interval, - vmin=vmin, - vmax=vmax, - stretch=stretch, - clip=False) - - elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)): - norm = scale - else: - raise ValueError("scale {} is not available.".format(scale)) - - if offset_axes: - extent = (offset_axes[0]-0.5, offset_axes[0] + image.shape[1]-0.5, offset_axes[1]-0.5, offset_axes[1] + image.shape[0]-0.5) - else: - extent = (-0.5, image.shape[1]-0.5, -0.5, image.shape[0]-0.5) - - if ax is None: - ax = plt.gca() - - # Set up the colormap to use. 
If a bad color is defined,
-    # add it to the colormap:
-    if cmap is None:
-        cmap = copy.copy(plt.get_cmap('Blues'))
-    elif isinstance(cmap, str):
-        cmap = copy.copy(plt.get_cmap(cmap))
-
-    if color_bad:
-        cmap.set_bad(color_bad, 1.0)
-
-    im = ax.imshow(image, cmap=cmap, norm=norm, origin=origin, extent=extent, interpolation='nearest', **kwargs)
-    if xlabel is not None:
-        ax.set_xlabel(xlabel)
-    if ylabel is not None:
-        ax.set_ylabel(ylabel)
-    if title is not None:
-        ax.set_title(title)
-    ax.set_xlim([extent[0], extent[1]])
-    ax.set_ylim([extent[2], extent[3]])
-
-    if cbar:
-        fig = ax.figure
-        divider = make_axes_locatable(ax)
-        if cbar == 'top':
-            cbar_pad = 0.05 if cbar_pad is None else cbar_pad
-            cax = divider.append_axes('top', size=cbar_size, pad=cbar_pad)
-            orientation = 'horizontal'
-        elif cbar == 'bottom':
-            cbar_pad = 0.35 if cbar_pad is None else cbar_pad
-            cax = divider.append_axes('bottom', size=cbar_size, pad=cbar_pad)
-            orientation = 'horizontal'
-        elif cbar == 'left':
-            cbar_pad = 0.35 if cbar_pad is None else cbar_pad
-            cax = divider.append_axes('left', size=cbar_size, pad=cbar_pad)
-            orientation = 'vertical'
-        else:
-            cbar_pad = 0.05 if cbar_pad is None else cbar_pad
-            cax = divider.append_axes('right', size=cbar_size, pad=cbar_pad)
-            orientation = 'vertical'
-
-        cb = fig.colorbar(im, cax=cax, orientation=orientation)
-
-        if cbar == 'top':
-            cax.xaxis.set_ticks_position('top')
-            cax.xaxis.set_label_position('top')
-        elif cbar == 'left':
-            cax.yaxis.set_ticks_position('left')
-            cax.yaxis.set_label_position('left')
-
-        if clabel is not None:
-            cb.set_label(clabel)
-        if cbar_ticks is not None:
-            cb.set_ticks(cbar_ticks)
-        if cbar_ticklabels is not None:
-            cb.set_ticklabels(cbar_ticklabels)
-
-        #cax.yaxis.set_major_locator(matplotlib.ticker.AutoLocator())
-        #cax.yaxis.set_minor_locator(matplotlib.ticker.AutoLocator())
-        cax.tick_params(which='both', direction='out', pad=5)
-
-    # Settings for ticks:
-    integer_locator = MaxNLocator(nbins=10, integer=True)
-    ax.xaxis.set_major_locator(integer_locator)
-    ax.xaxis.set_minor_locator(integer_locator)
-    ax.yaxis.set_major_locator(integer_locator)
-    ax.yaxis.set_minor_locator(integer_locator)
-    ax.tick_params(which='both', direction='out', pad=5)
-    ax.xaxis.tick_bottom()
-    ax.yaxis.tick_left()
-
-    return im
+    """
+    Change plotting to use a non-interactive backend, which can e.g. be used on a cluster.
+    Will set backend to 'Agg'.
+
+    .. codeauthor:: Rasmus Handberg
+    """
+    plt.switch_backend('Agg')
+
+
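+# A minimal usage sketch of the two backend helpers and of plot_image below
+# (the random test image is invented for the demo):
+#
+#     plots_noninteractive()                 # e.g. on a cluster
+#     img = np.random.randn(64, 64)          # synthetic image
+#     im = plot_image(img, scale='linear', percentile=95.0, cbar='right')
+#     im.figure.savefig('demo.png')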
+# --------------------------------------------------------------------------------------------------
+def plot_image(image, ax=None, scale='log', cmap=None, origin='lower', xlabel=None, ylabel=None, cbar=None,
+               clabel='Flux ($e^{-}s^{-1}$)', cbar_ticks=None, cbar_ticklabels=None, cbar_pad=None, cbar_size='4%',
+               title=None, percentile=95.0, vmin=None, vmax=None, offset_axes=None, color_bad='k', **kwargs):
+    """
+    Utility function to plot a 2D image.
+
+    Parameters:
+        image (2d array): Image data.
+        ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
+            Default (None) is to use current active axes.
+        scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
+            Normalization used to stretch the colormap.
+            Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
+            and ``'squared'``.
+            Can also be a :py:class:`astropy.visualization.ImageNormalize` object.
+            Default is ``'log'``.
+        origin (str, optional): The origin of the coordinate system.
+        xlabel (str, optional): Label for the x-axis.
+        ylabel (str, optional): Label for the y-axis.
+        cbar (string, optional): Location of color bar.
+            Choices are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
+            Default is not to create colorbar.
+        clabel (str, optional): Label for the color bar.
+        cbar_size (float, optional): Fractional size of colorbar compared to axes. Default='4%'.
+        cbar_pad (float, optional): Padding between axes and colorbar.
+        title (str or None, optional): Title for the plot.
+        percentile (float, optional): The fraction of pixels to keep in color-trim.
+            If single float given, the same fraction of pixels is eliminated from both ends.
+            If tuple of two floats is given, the two are used as the percentiles.
+            Default=95.
+        cmap (matplotlib colormap, optional): Colormap to use. Default is the ``Blues`` colormap.
+        vmin (float, optional): Lower limit to use for colormap.
+        vmax (float, optional): Upper limit to use for colormap.
+        color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
+        kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.
+
+    Returns:
+        :py:class:`matplotlib.image.AxesImage`: Image returned
+            by :py:func:`matplotlib.pyplot.imshow`.
+
+    .. codeauthor:: Rasmus Handberg
+    """
+
+    logger = logging.getLogger(__name__)
+
+    # Backward compatible settings:
+    make_cbar = kwargs.pop('make_cbar', None)
+    if make_cbar:
+        # Warn instead of raising, so the deprecated option keeps working:
+        import warnings
+        warnings.warn("'make_cbar' is deprecated. Use 'cbar' instead.", FutureWarning)
+        if not cbar:
+            cbar = make_cbar
+
+    # Special treatment for boolean arrays:
+    if isinstance(image, np.ndarray) and image.dtype == 'bool':
+        if vmin is None: vmin = 0
+        if vmax is None: vmax = 1
+        if cbar_ticks is None: cbar_ticks = [0, 1]
+        if cbar_ticklabels is None: cbar_ticklabels = ['False', 'True']
+
+    # Calculate limits of color scaling:
+    interval = None
+    if vmin is None or vmax is None:
+        if allnan(image):
+            logger.warning("Image is all NaN")
+            vmin = 0
+            vmax = 1
+            if cbar_ticks is None:
+                cbar_ticks = []
+            if cbar_ticklabels is None:
+                cbar_ticklabels = []
+        elif isinstance(percentile, (list, tuple, np.ndarray)):
+            interval = viz.AsymmetricPercentileInterval(percentile[0], percentile[1])
+        else:
+            interval = viz.PercentileInterval(percentile)
+
+    # Create ImageNormalize object with extracted limits:
+    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh', 'squared'):
+        if scale == 'log':
+            stretch = viz.LogStretch()
+        elif scale == 'linear':
+            stretch = viz.LinearStretch()
+        elif scale == 'sqrt':
+            stretch = viz.SqrtStretch()
+        elif scale == 'asinh':
+            stretch = viz.AsinhStretch()
+        elif scale == 'histeq':
+            stretch = viz.HistEqStretch(image[np.isfinite(image)])
+        elif scale == 'sinh':
+            stretch = viz.SinhStretch()
+        elif scale == 'squared':
+            stretch = viz.SquaredStretch()
+
+        # Create ImageNormalize object. Very important to use clip=False here, otherwise
+        # NaN points will not be plotted correctly.
+        norm = viz.ImageNormalize(data=image, interval=interval, vmin=vmin, vmax=vmax, stretch=stretch, clip=False)
+
+    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
+        norm = scale
+    else:
+        raise ValueError("scale {} is not available.".format(scale))
+
+    if offset_axes:
+        extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5, offset_axes[1] - 0.5,
+                  offset_axes[1] + image.shape[0] - 0.5)
+    else:
+        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)
+
+    if ax is None:
+        ax = plt.gca()
+
+    # Set up the colormap to use. 
If a bad color is defined, + # add it to the colormap: + if cmap is None: + cmap = copy.copy(plt.get_cmap('Blues')) + elif isinstance(cmap, str): + cmap = copy.copy(plt.get_cmap(cmap)) + + if color_bad: + cmap.set_bad(color_bad, 1.0) + + im = ax.imshow(image, cmap=cmap, norm=norm, origin=origin, extent=extent, interpolation='nearest', **kwargs) + if xlabel is not None: + ax.set_xlabel(xlabel) + if ylabel is not None: + ax.set_ylabel(ylabel) + if title is not None: + ax.set_title(title) + ax.set_xlim([extent[0], extent[1]]) + ax.set_ylim([extent[2], extent[3]]) + + if cbar: + fig = ax.figure + divider = make_axes_locatable(ax) + if cbar == 'top': + cbar_pad = 0.05 if cbar_pad is None else cbar_pad + cax = divider.append_axes('top', size=cbar_size, pad=cbar_pad) + orientation = 'horizontal' + elif cbar == 'bottom': + cbar_pad = 0.35 if cbar_pad is None else cbar_pad + cax = divider.append_axes('bottom', size=cbar_size, pad=cbar_pad) + orientation = 'horizontal' + elif cbar == 'left': + cbar_pad = 0.35 if cbar_pad is None else cbar_pad + cax = divider.append_axes('left', size=cbar_size, pad=cbar_pad) + orientation = 'vertical' + else: + cbar_pad = 0.05 if cbar_pad is None else cbar_pad + cax = divider.append_axes('right', size=cbar_size, pad=cbar_pad) + orientation = 'vertical' + + cb = fig.colorbar(im, cax=cax, orientation=orientation) + + if cbar == 'top': + cax.xaxis.set_ticks_position('top') + cax.xaxis.set_label_position('top') + elif cbar == 'left': + cax.yaxis.set_ticks_position('left') + cax.yaxis.set_label_position('left') + + if clabel is not None: + cb.set_label(clabel) + if cbar_ticks is not None: + cb.set_ticks(cbar_ticks) + if cbar_ticklabels is not None: + cb.set_ticklabels(cbar_ticklabels) + + # cax.yaxis.set_major_locator(matplotlib.ticker.AutoLocator()) + # cax.yaxis.set_minor_locator(matplotlib.ticker.AutoLocator()) + cax.tick_params(which='both', direction='out', pad=5) + + # Settings for ticks: + integer_locator = MaxNLocator(nbins=10, integer=True) + ax.xaxis.set_major_locator(integer_locator) + ax.xaxis.set_minor_locator(integer_locator) + ax.yaxis.set_major_locator(integer_locator) + ax.yaxis.set_minor_locator(integer_locator) + ax.tick_params(which='both', direction='out', pad=5) + ax.xaxis.tick_bottom() + ax.yaxis.tick_left() + + return im diff --git a/flows/reference_cleaning.py b/flows/reference_cleaning.py index 5d2c689..27daac5 100644 --- a/flows/reference_cleaning.py +++ b/flows/reference_cleaning.py @@ -18,313 +18,300 @@ from scipy.spatial import KDTree import pandas as pd # TODO: Convert to pure numpy implementation -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class MinStarError(RuntimeError): - pass - -#-------------------------------------------------------------------------------------------------- -def force_reject_g2d(xarray, yarray, image, get_fwhm=True, rsq_min=0.5, radius=10, fwhm_guess=6.0, - fwhm_min=3.5, fwhm_max=18.0): - """ - - Parameters: - xarray: - yarray: - image: - get_fwhm (bool, optional): - rsq_min (float, optional): - radius (float, optional): - fwhm_guess=6.0: - fwhm_min=3.5: - fwhm_max=18.0: - - Returns: - tuple: - - masked_xys: - - mask: - - masked_rsqs: - - .. codeauthor:: Emir Karamehmetoglu - .. 
codeauthor:: Rasmus Handberg - """ - # Set up 2D Gaussian model for fitting to reference stars: - g2d = models.Gaussian2D(amplitude=1.0, - x_mean=radius, - y_mean=radius, - x_stddev=fwhm_guess * gaussian_fwhm_to_sigma) - g2d.amplitude.bounds = (0.1, 2.0) - g2d.x_mean.bounds = (0.5 * radius, 1.5 * radius) - g2d.y_mean.bounds = (0.5 * radius, 1.5 * radius) - g2d.x_stddev.bounds = ( - fwhm_min * gaussian_fwhm_to_sigma, - fwhm_max * gaussian_fwhm_to_sigma - ) - g2d.y_stddev.tied = lambda model: model.x_stddev - g2d.theta.fixed = True - - gfitter = fitting.LevMarLSQFitter() - - # Stars reject - N = len(xarray) - fwhms = np.full((N, 2), np.NaN) - xys = np.full((N, 2), np.NaN) - rsqs = np.full(N, np.NaN) - for i, (x, y) in enumerate(zip(xarray, yarray)): - x = int(np.round(x)) - y = int(np.round(y)) - xmin = max(x - radius, 0) - xmax = min(x + radius + 1, image.shape[1]) - ymin = max(y - radius, 0) - ymax = min(y + radius + 1, image.shape[0]) - - curr_star = deepcopy(image.subclean[ymin:ymax, xmin:xmax]) - - edge = np.zeros_like(curr_star, dtype='bool') - edge[(0, -1), :] = True - edge[:, (0, -1)] = True - curr_star -= nanmedian(curr_star[edge]) - curr_star /= np.nanmax(curr_star) - - ypos, xpos = np.mgrid[:curr_star.shape[0], :curr_star.shape[1]] - gfit = gfitter(g2d, x=xpos, y=ypos, z=curr_star) - - # Center - xys[i] = np.array([gfit.x_mean + x - radius, gfit.y_mean + y - radius], dtype='float64') - - # Calculate rsq - sstot = nansum((curr_star - nanmean(curr_star))**2) - sserr = nansum(gfitter.fit_info['fvec']**2) - rsqs[i] = 0 if sstot == 0 else 1.0 - (sserr / sstot) - - # FWHM - fwhms[i] = gfit.x_fwhm - - masked_xys = np.ma.masked_array(xys, ~np.isfinite(xys)) - masked_rsqs = np.ma.masked_array(rsqs, ~np.isfinite(rsqs)) - mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0) # Reject Rsq < rsq_min - # changed - #masked_xys = masked_xys[mask] # Clean extracted array. - # to - masked_xys.mask[~mask] = True - # don't know if it breaks anything, but it doesn't make sence if - # len(masked_xys) != len(masked_rsqs) FIXME - masked_fwhms = np.ma.masked_array(fwhms, ~np.isfinite(fwhms)) - - if get_fwhm: - return masked_fwhms, masked_xys, mask, masked_rsqs - return masked_xys, mask, masked_rsqs - -#-------------------------------------------------------------------------------------------------- -def clean_with_rsq_and_get_fwhm(masked_fwhms, masked_rsqs, references, min_fwhm_references=2, - min_references=6, rsq_min=0.15): - """ - Clean references and obtain fwhm using RSQ values. - - Parameters: - masked_fwhms (np.ma.maskedarray): array of fwhms - masked_rsqs (np.ma.maskedarray): array of rsq values - references (astropy.table.Table): table or reference stars - min_fwhm_references: (Default 2) min stars to get a fwhm - min_references: (Default 6) min stars to aim for when cutting by R2 - rsq_min: (Default 0.15) min rsq value - - .. 
codeauthor:: Emir Karamehmetoglu - """ - min_references_now = min_references - rsqvals = np.arange(rsq_min, 0.95, 0.15)[::-1] - fwhm_found = False - min_references_achieved = False - - # Clean based on R^2 Value - while not min_references_achieved: - for rsqval in rsqvals: - mask = (masked_rsqs >= rsqval) & (masked_rsqs < 1.0) - nreferences = np.sum(np.isfinite(masked_fwhms[mask])) - if nreferences >= min_fwhm_references: - _fwhms_cut_ = np.nanmean(sigma_clip(masked_fwhms[mask], maxiters=100, sigma=2.0)) - if not fwhm_found: - fwhm = _fwhms_cut_ - fwhm_found = True - if nreferences >= min_references_now: - references = references[mask] - min_references_achieved = True - break - if min_references_achieved: - break - min_references_now = min_references_now - 2 - if (min_references_now < 2) and fwhm_found: - break - elif not fwhm_found: - raise RuntimeError("Could not estimate FWHM") - - if np.isnan(fwhm): - raise RuntimeError("Could not estimate FWHM") - - # if minimum references not found, then take what we can get with even a weaker cut. - # TODO: Is this right, or should we grab rsq_min (or even weaker?) - min_references_now = min_references - 2 - while not min_references_achieved: - mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0) - nreferences = np.sum(np.isfinite(masked_fwhms[mask])) - if nreferences >= min_references_now: - references = references[mask] - min_references_achieved = True - rsq_min = rsq_min - 0.07 - min_references_now = min_references_now - 1 - - # Check len of references as this is a destructive cleaning. - # if len(references) == 2: logger.info('2 reference stars remaining, check WCS and image quality') - if len(references) < 2: - raise RuntimeError(f"{len(references)} References remaining; could not clean.") - return fwhm, references - -#-------------------------------------------------------------------------------------------------- + pass + + +# -------------------------------------------------------------------------------------------------- +def force_reject_g2d(xarray, yarray, image, get_fwhm=True, rsq_min=0.5, radius=10, fwhm_guess=6.0, fwhm_min=3.5, + fwhm_max=18.0): + """ + + Parameters: + xarray: + yarray: + image: + get_fwhm (bool, optional): + rsq_min (float, optional): + radius (float, optional): + fwhm_guess=6.0: + fwhm_min=3.5: + fwhm_max=18.0: + + Returns: + tuple: + - masked_xys: + - mask: + - masked_rsqs: + + .. codeauthor:: Emir Karamehmetoglu + .. 
codeauthor:: Rasmus Handberg
+    """
+    # Set up 2D Gaussian model for fitting to reference stars:
+    g2d = models.Gaussian2D(amplitude=1.0, x_mean=radius, y_mean=radius, x_stddev=fwhm_guess * gaussian_fwhm_to_sigma)
+    g2d.amplitude.bounds = (0.1, 2.0)
+    g2d.x_mean.bounds = (0.5 * radius, 1.5 * radius)
+    g2d.y_mean.bounds = (0.5 * radius, 1.5 * radius)
+    g2d.x_stddev.bounds = (fwhm_min * gaussian_fwhm_to_sigma, fwhm_max * gaussian_fwhm_to_sigma)
+    g2d.y_stddev.tied = lambda model: model.x_stddev
+    g2d.theta.fixed = True
+
+    gfitter = fitting.LevMarLSQFitter()
+
+    # Stars reject
+    N = len(xarray)
+    fwhms = np.full((N, 2), np.NaN)
+    xys = np.full((N, 2), np.NaN)
+    rsqs = np.full(N, np.NaN)
+    for i, (x, y) in enumerate(zip(xarray, yarray)):
+        x = int(np.round(x))
+        y = int(np.round(y))
+        xmin = max(x - radius, 0)
+        xmax = min(x + radius + 1, image.shape[1])
+        ymin = max(y - radius, 0)
+        ymax = min(y + radius + 1, image.shape[0])
+
+        curr_star = deepcopy(image.subclean[ymin:ymax, xmin:xmax])
+
+        edge = np.zeros_like(curr_star, dtype='bool')
+        edge[(0, -1), :] = True
+        edge[:, (0, -1)] = True
+        curr_star -= nanmedian(curr_star[edge])
+        curr_star /= np.nanmax(curr_star)
+
+        ypos, xpos = np.mgrid[:curr_star.shape[0], :curr_star.shape[1]]
+        gfit = gfitter(g2d, x=xpos, y=ypos, z=curr_star)
+
+        # Center
+        xys[i] = np.array([gfit.x_mean + x - radius, gfit.y_mean + y - radius], dtype='float64')
+
+        # Calculate rsq
+        sstot = nansum((curr_star - nanmean(curr_star)) ** 2)
+        sserr = nansum(gfitter.fit_info['fvec'] ** 2)
+        rsqs[i] = 0 if sstot == 0 else 1.0 - (sserr / sstot)
+
+        # FWHM
+        fwhms[i] = gfit.x_fwhm
+
+    masked_xys = np.ma.masked_array(xys, ~np.isfinite(xys))
+    masked_rsqs = np.ma.masked_array(rsqs, ~np.isfinite(rsqs))
+    mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0)  # Reject Rsq < rsq_min
+    # changed
+    # masked_xys = masked_xys[mask]  # Clean extracted array.
+    # to
+    masked_xys.mask[~mask] = True
+    # don't know if it breaks anything, but it doesn't make sense if
+    # len(masked_xys) != len(masked_rsqs) FIXME
+    masked_fwhms = np.ma.masked_array(fwhms, ~np.isfinite(fwhms))
+
+    if get_fwhm:
+        return masked_fwhms, masked_xys, mask, masked_rsqs
+    return masked_xys, mask, masked_rsqs
+
+
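+# How the fitted FWHMs and R^2 values are meant to chain into the cleaning step
+# defined below, as a sketch (`image` and the `references` table, with hypothetical
+# pixel columns 'x' and 'y', are assumed to come from the photometry pipeline):
+#
+#     fwhms, xys, mask, rsqs = force_reject_g2d(references['x'], references['y'],
+#                                               image, get_fwhm=True)
+#     fwhm, clean_refs = clean_with_rsq_and_get_fwhm(fwhms, rsqs, references)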
+# --------------------------------------------------------------------------------------------------
+def clean_with_rsq_and_get_fwhm(masked_fwhms, masked_rsqs, references, min_fwhm_references=2, min_references=6,
+                                rsq_min=0.15):
+    """
+    Clean references and obtain fwhm using RSQ values.
+
+    Parameters:
+        masked_fwhms (np.ma.maskedarray): array of fwhms
+        masked_rsqs (np.ma.maskedarray): array of rsq values
+        references (astropy.table.Table): table of reference stars
+        min_fwhm_references: (Default 2) min stars to get a fwhm
+        min_references: (Default 6) min stars to aim for when cutting by R2
+        rsq_min: (Default 0.15) min rsq value
+
+    .. codeauthor:: Emir Karamehmetoglu
+    """
+    min_references_now = min_references
+    rsqvals = np.arange(rsq_min, 0.95, 0.15)[::-1]
+    fwhm_found = False
+    min_references_achieved = False
+
+    # Clean based on R^2 Value
+    while not min_references_achieved:
+        for rsqval in rsqvals:
+            mask = (masked_rsqs >= rsqval) & (masked_rsqs < 1.0)
+            nreferences = np.sum(np.isfinite(masked_fwhms[mask]))
+            if nreferences >= min_fwhm_references:
+                _fwhms_cut_ = np.nanmean(sigma_clip(masked_fwhms[mask], maxiters=100, sigma=2.0))
+                if not fwhm_found:
+                    fwhm = _fwhms_cut_
+                    fwhm_found = True
+                if nreferences >= min_references_now:
+                    references = references[mask]
+                    min_references_achieved = True
+                    break
+        if min_references_achieved:
+            break
+        min_references_now = min_references_now - 2
+        if (min_references_now < 2) and fwhm_found:
+            break
+        elif not fwhm_found:
+            raise RuntimeError("Could not estimate FWHM")
+
+    if np.isnan(fwhm):
+        raise RuntimeError("Could not estimate FWHM")
+
+    # if minimum references not found, then take what we can get with even a weaker cut.
+    # TODO: Is this right, or should we grab rsq_min (or even weaker?)
+    min_references_now = min_references - 2
+    while not min_references_achieved:
+        mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0)
+        nreferences = np.sum(np.isfinite(masked_fwhms[mask]))
+        if nreferences >= min_references_now:
+            references = references[mask]
+            min_references_achieved = True
+        rsq_min = rsq_min - 0.07
+        min_references_now = min_references_now - 1
+
+    # Check len of references as this is a destructive cleaning.
+    # if len(references) == 2: logger.info('2 reference stars remaining, check WCS and image quality')
+    if len(references) < 2:
+        raise RuntimeError(f"{len(references)} References remaining; could not clean.")
+    return fwhm, references
+
+
+# --------------------------------------------------------------------------------------------------
 def mkposxy(posx, posy):
-    '''Make 2D np array for astroalign'''
-    img_posxy = np.array([[x, y] for x, y in zip(posx, posy)], dtype="float64")
-    return img_posxy
+    '''Make 2D np array for astroalign'''
+    img_posxy = np.array([[x, y] for x, y in zip(posx, posy)], dtype="float64")
+    return img_posxy
+
 
-#--------------------------------------------------------------------------------------------------
+# --------------------------------------------------------------------------------------------------
 def try_transform(source, target, pixeltol=2, nnearest=5, max_stars=50):
-    aa.NUM_NEAREST_NEIGHBORS = nnearest
-    aa.PIXEL_TOL = pixeltol
-    transform, (sourcestars, targetstars) = aa.find_transform(
-        source,
-        target,
-        max_control_points=max_stars)
-    return sourcestars, targetstars
-
-#--------------------------------------------------------------------------------------------------
+    aa.NUM_NEAREST_NEIGHBORS = nnearest
+    aa.PIXEL_TOL = pixeltol
+    transform, (sourcestars, targetstars) = aa.find_transform(source, target, max_control_points=max_stars)
+    return sourcestars, targetstars
+
+
+# --------------------------------------------------------------------------------------------------
 def try_astroalign(source, target, pixeltol=2, nnearest=5, max_stars_n=50):
-    # Get indexes of matched stars
-    success = False
-    try:
-        source_stars, target_stars = try_transform(
-            source,
-            target,
-            pixeltol=pixeltol,
-            nnearest=nnearest,
-            max_stars=max_stars_n)
-        source_ind = np.argwhere(np.in1d(source, source_stars)[::2]).flatten()
-        target_ind = np.argwhere(np.in1d(target, target_stars)[::2]).flatten()
-        success = True
-    except aa.MaxIterError:
- source_ind, target_ind = 'None', 'None' - return source_ind, target_ind, success - -#-------------------------------------------------------------------------------------------------- -def min_to_max_astroalign(source, target, fwhm=5, fwhm_min=1, fwhm_max=4, knn_min=5, - knn_max=20, max_stars=100, min_matches=3): - """Try to find matches using astroalign asterisms by stepping through some parameters.""" - # Set max_control_points par based on number of stars and max_stars. - nstars = max(len(source), len(source)) - if max_stars >= nstars: - max_stars_list = 'None' - else: - if max_stars > 60: - max_stars_list = (max_stars, 50, 4, 3) - else: - max_stars_list = (max_stars, 6, 4, 3) - - # Create max_stars step-through list if not given - if max_stars_list == 'None': - if nstars > 6: - max_stars_list = (nstars, 5, 3) - elif nstars > 3: - max_stars_list = (nstars, 3) - - pixeltols = np.linspace(int(fwhm * fwhm_min), int(fwhm * fwhm_max), 4, dtype=int) - nearest_neighbors = np.linspace(knn_min, min(knn_max, nstars), 4, dtype=int) - - for max_stars_n in max_stars_list: - for pixeltol in pixeltols: - for nnearest in nearest_neighbors: - source_ind, target_ind, success = try_astroalign(source, target, - pixeltol=pixeltol, - nnearest=nnearest, - max_stars_n=max_stars_n) - if success: - if len(source_ind) >= min_matches: - return source_ind, target_ind, success - else: - success = False - return 'None', 'None', success - -#-------------------------------------------------------------------------------------------------- + # Get indexes of matched stars + success = False + try: + source_stars, target_stars = try_transform(source, target, pixeltol=pixeltol, nnearest=nnearest, + max_stars=max_stars_n) + source_ind = np.argwhere(np.in1d(source, source_stars)[::2]).flatten() + target_ind = np.argwhere(np.in1d(target, target_stars)[::2]).flatten() + success = True + except aa.MaxIterError: + source_ind, target_ind = 'None', 'None' + return source_ind, target_ind, success + + +# -------------------------------------------------------------------------------------------------- +def min_to_max_astroalign(source, target, fwhm=5, fwhm_min=1, fwhm_max=4, knn_min=5, knn_max=20, max_stars=100, + min_matches=3): + """Try to find matches using astroalign asterisms by stepping through some parameters.""" + # Set max_control_points par based on number of stars and max_stars. 
+    nstars = max(len(source), len(target))
+    if max_stars >= nstars:
+        max_stars_list = 'None'
+    else:
+        if max_stars > 60:
+            max_stars_list = (max_stars, 50, 4, 3)
+        else:
+            max_stars_list = (max_stars, 6, 4, 3)
+
+    # Create max_stars step-through list if not given
+    if max_stars_list == 'None':
+        if nstars > 6:
+            max_stars_list = (nstars, 5, 3)
+        elif nstars > 3:
+            max_stars_list = (nstars, 3)
+
+    pixeltols = np.linspace(int(fwhm * fwhm_min), int(fwhm * fwhm_max), 4, dtype=int)
+    nearest_neighbors = np.linspace(knn_min, min(knn_max, nstars), 4, dtype=int)
+
+    for max_stars_n in max_stars_list:
+        for pixeltol in pixeltols:
+            for nnearest in nearest_neighbors:
+                source_ind, target_ind, success = try_astroalign(source, target, pixeltol=pixeltol, nnearest=nnearest,
+                                                                 max_stars_n=max_stars_n)
+                if success:
+                    if len(source_ind) >= min_matches:
+                        return source_ind, target_ind, success
+                    else:
+                        success = False
+    return 'None', 'None', success
+
+
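+# Sketch of calling the stepping routine above on two point lists (positions are
+# invented and may be too few for a real asterism match; mkposxy is the helper
+# defined earlier in this module):
+#
+#     src = mkposxy([10.0, 50.0, 90.0, 30.0], [12.0, 48.0, 85.0, 60.0])
+#     tgt = mkposxy([11.0, 51.0, 91.0, 31.0], [13.0, 49.0, 86.0, 61.0])
+#     src_ind, tgt_ind, ok = min_to_max_astroalign(src, tgt, fwhm=5)
+#     if ok:
+#         matched = src[src_ind]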
+# --------------------------------------------------------------------------------------------------
 def kdtree(source, target, fwhm=5, fwhm_max=4, min_matches=3):
-    '''Use KDTree to get nearest neighbor matches within fwhm_max*fwhm distance'''
-
-    # Use KDTree to rapidly efficiently query nearest neighbors
-
-    tt = KDTree(target)
-    st = KDTree(source)
-    matches_list = st.query_ball_tree(tt, r=fwhm * fwhm_max)
-
-    #indx = []
-    targets = []
-    sources = []
-    for j, (sstar, match) in enumerate(zip(source, matches_list)):
-        if np.array(target[match]).size != 0:
-            targets.append(match[0])
-            sources.append(j)
-    sources = np.array(sources, dtype=int)
-    targets = np.array(targets, dtype=int)
-
-    # Return indexes of matches
-    return sources, targets, len(sources) >= min_matches
-
-#--------------------------------------------------------------------------------------------------
-def get_new_wcs(extracted_ind, extracted_stars, clean_references, ref_ind, obstime, rakey='ra_obs',
-        deckey='decl_obs'):
-
-    targets = (extracted_stars[extracted_ind][:, 0], extracted_stars[extracted_ind][:, 1])
-
-    c = SkyCoord(
-        ra=clean_references[rakey][ref_ind],
-        dec=clean_references[deckey][ref_ind],
-        frame='icrs',
-        obstime=obstime
-    )
-    return wcs.utils.fit_wcs_from_points(targets, c)
-
-#--------------------------------------------------------------------------------------------------
-def get_clean_references(references, masked_rsqs, min_references_ideal=6, min_references_abs=3,
-        rsq_min=0.15, rsq_ideal=0.5, keep_max=100, rescue_bad: bool = True):
-
-    # Greedy first try
-    mask = (masked_rsqs >= rsq_ideal) & (masked_rsqs < 1.0)
-    if np.sum(np.isfinite(masked_rsqs[mask])) >= min_references_ideal:
-        if len(references[mask]) <= keep_max:
-            return references[mask]
-        elif len(references[mask]) >= keep_max:
-
-            df = pd.DataFrame(masked_rsqs, columns=['rsq'])
-            masked_rsqs.mask = ~mask
-            nmasked_rsqs = df.sort_values('rsq', ascending=False).dropna().index._data
-            return references[nmasked_rsqs[:keep_max]]
-
-    # Desperate second try
-    mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0)
-    masked_rsqs.mask = ~mask
-
-    # Switching to pandas for easier selection
-    df = pd.DataFrame(masked_rsqs, columns=['rsq'])
-    nmasked_rsqs = deepcopy(
-        df.sort_values('rsq', ascending=False).dropna().index._data)
-    nmasked_rsqs = nmasked_rsqs[:min(min_references_ideal, len(nmasked_rsqs))]
-    if len(nmasked_rsqs) >= min_references_abs:
-        return references[nmasked_rsqs]
-    if not rescue_bad:
-        raise MinStarError(f'Less than {min_references_abs} clean stars and rescue_bad = False')
-
-    # Extremely desperate last ditch attempt i.e. "rescue bad"
-    mask = (masked_rsqs >= 0.02) & (masked_rsqs < 1.0)
-    masked_rsqs.mask = ~mask
-
-    # Switch to pandas
-    df = pd.DataFrame(masked_rsqs, columns=['rsq'])
-    nmasked_rsqs = df.sort_values('rsq', ascending=False).dropna().index._data
-    nmasked_rsqs = nmasked_rsqs[:min(min_references_ideal, len(nmasked_rsqs))]
-    if len(nmasked_rsqs) < 2:
-        raise MinStarError('Less than 2 clean stars.')
-    return references[nmasked_rsqs]  # Return if len >= 2
+    '''Use KDTree to get nearest neighbor matches within fwhm_max*fwhm distance'''
+
+    # Use KDTree to rapidly and efficiently query nearest neighbors
+
+    tt = KDTree(target)
+    st = KDTree(source)
+    matches_list = st.query_ball_tree(tt, r=fwhm * fwhm_max)
+
+    # indx = []
+    targets = []
+    sources = []
+    for j, (sstar, match) in enumerate(zip(source, matches_list)):
+        if np.array(target[match]).size != 0:
+            targets.append(match[0])
+            sources.append(j)
+    sources = np.array(sources, dtype=int)
+    targets = np.array(targets, dtype=int)
+
+    # Return indexes of matches
+    return sources, targets, len(sources) >= min_matches
+
+
+# --------------------------------------------------------------------------------------------------
+def get_new_wcs(extracted_ind, extracted_stars, clean_references, ref_ind, obstime, rakey='ra_obs', deckey='decl_obs'):
+    targets = (extracted_stars[extracted_ind][:, 0], extracted_stars[extracted_ind][:, 1])
+
+    c = SkyCoord(ra=clean_references[rakey][ref_ind], dec=clean_references[deckey][ref_ind], frame='icrs',
+                 obstime=obstime)
+    return wcs.utils.fit_wcs_from_points(targets, c)
+
+
+# --------------------------------------------------------------------------------------------------
+def get_clean_references(references, masked_rsqs, min_references_ideal=6, min_references_abs=3, rsq_min=0.15,
+                         rsq_ideal=0.5, keep_max=100, rescue_bad: bool = True):
+    # Greedy first try
+    mask = (masked_rsqs >= rsq_ideal) & (masked_rsqs < 1.0)
+    if np.sum(np.isfinite(masked_rsqs[mask])) >= min_references_ideal:
+        if len(references[mask]) <= keep_max:
+            return references[mask]
+        elif len(references[mask]) >= keep_max:
+
+            df = pd.DataFrame(masked_rsqs, columns=['rsq'])
+            masked_rsqs.mask = ~mask
+            nmasked_rsqs = df.sort_values('rsq', ascending=False).dropna().index._data
+            return references[nmasked_rsqs[:keep_max]]
+
+    # Desperate second try
+    mask = (masked_rsqs >= rsq_min) & (masked_rsqs < 1.0)
+    masked_rsqs.mask = ~mask
+
+    # Switching to pandas for easier selection
+    df = pd.DataFrame(masked_rsqs, columns=['rsq'])
+    nmasked_rsqs = deepcopy(df.sort_values('rsq', ascending=False).dropna().index._data)
+    nmasked_rsqs = nmasked_rsqs[:min(min_references_ideal, len(nmasked_rsqs))]
+    if len(nmasked_rsqs) >= min_references_abs:
+        return references[nmasked_rsqs]
+    if not rescue_bad:
+        raise MinStarError(f'Less than {min_references_abs} clean stars and rescue_bad = False')
+
+    # Extremely desperate last ditch attempt i.e. 
"rescue bad" + mask = (masked_rsqs >= 0.02) & (masked_rsqs < 1.0) + masked_rsqs.mask = ~mask + + # Switch to pandas + df = pd.DataFrame(masked_rsqs, columns=['rsq']) + nmasked_rsqs = df.sort_values('rsq', ascending=False).dropna().index._data + nmasked_rsqs = nmasked_rsqs[:min(min_references_ideal, len(nmasked_rsqs))] + if len(nmasked_rsqs) < 2: + raise MinStarError('Less than 2 clean stars.') + return references[nmasked_rsqs] # Return if len >= 2 diff --git a/flows/run_imagematch.py b/flows/run_imagematch.py index 5462905..1096291 100644 --- a/flows/run_imagematch.py +++ b/flows/run_imagematch.py @@ -17,18 +17,19 @@ import re from astropy.io import fits from astropy.wcs.utils import proj_plane_pixel_area -#from setuptools import Distribution -#from setuptools.command.install import install +# from setuptools import Distribution +# from setuptools.command.install import install from .load_image import load_image from . import api -#-------------------------------------------------------------------------------------------------- -#class OnlyGetScriptPath(install): + +# -------------------------------------------------------------------------------------------------- +# class OnlyGetScriptPath(install): # def run(self): # # does not call install.run() by design # self.distribution.install_scripts = self.install_scripts -#def get_setuptools_script_dir(): +# def get_setuptools_script_dir(): # dist = Distribution({'cmdclass': {'install': OnlyGetScriptPath}}) # dist.dry_run = True # not sure if necessary, but to be safe # dist.parse_config_files() @@ -37,152 +38,140 @@ # command.run() # return dist.install_scripts -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def run_imagematch(datafile, target=None, star_coord=None, fwhm=None, pixel_scale=None): - """ - Run ImageMatch on a datafile. - - Parameters: - datafile (dict): Data file to run ImageMatch on. - target (:class:`astropy.table.Table`, optional): Target informaton. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - if datafile.get('template') is None: - raise ValueError("DATAFILE input does not specify a template to use.") - - # Extract paths to science and reference images: - reference_image = os.path.join(datafile['archive_path'], datafile['template']['path']) - science_image = os.path.join(datafile['archive_path'], datafile['path']) - - # If the target was not provided in the function call, - # use the API to get the target information: - if target is None: - catalog = api.get_catalog(datafile['targetid'], output='table') - target = catalog['target'][0] - - # Find the path to where the ImageMatch program is installed. - # This is to avoid problems with it not being on the users PATH - # and if the user is using some other version of the python executable. - # TODO: There must be a better way of doing this! 
- #imgmatch = os.path.join(get_setuptools_script_dir(), 'ImageMatch') - if os.name == "nt": - out = subprocess.check_output(["where", "ImageMatch"], universal_newlines=True) - imgmatch = out.strip() - else: - out = subprocess.check_output(["whereis", "ImageMatch"], universal_newlines=True) - out = re.match('ImageMatch: (.+)', out.strip()) - imgmatch = out.group(1) - - if not os.path.isfile(imgmatch): - raise FileNotFoundError("ImageMatch not found") - - # Find the ImageMatch config file to use based on the site of the observations: - __dir__ = os.path.dirname(os.path.abspath(__file__)) - if datafile['site'] in (1,3,4,6): - config_file = os.path.join(__dir__, 'imagematch', 'imagematch_lcogt.cfg') - elif datafile['site'] == 2: - config_file = os.path.join(__dir__, 'imagematch', 'imagematch_hawki.cfg') - elif datafile['site'] == 5: - config_file = os.path.join(__dir__, 'imagematch', 'imagematch_alfosc.cfg') - else: - config_file = os.path.join(__dir__, 'imagematch', 'imagematch_default.cfg') - if not os.path.isfile(config_file): - raise FileNotFoundError(config_file) - - if pixel_scale is None: - if datafile['site'] in (1,3,4,6): - # LCOGT provides the pixel scale directly in the header - pixel_scale = 'PIXSCALE' - else: - image = load_image(science_image) - pixel_area = proj_plane_pixel_area(image.wcs) - pixel_scale = np.sqrt(pixel_area)*3600 # arcsec/pixel - logger.info("Calculated science image pixel scale: %f", pixel_scale) - - if datafile['template']['site'] in (1,3,4,6): - # LCOGT provides the pixel scale directly in the header - mscale = 'PIXSCALE' - else: - template = load_image(reference_image) - template_pixel_area = proj_plane_pixel_area(template.wcs.celestial) - mscale = np.sqrt(template_pixel_area)*3600 # arcsec/pixel - logger.info("Calculated template pixel scale: %f", mscale) - - # Scale kernel radius with FWHM: - if fwhm is None: - kernel_radius = 9 - else: - kernel_radius = max(9, int(np.ceil(1.5*fwhm))) - if kernel_radius % 2 == 0: - kernel_radius += 1 - - # We will work in a temporary directory, since ImageMatch produces - # a lot of extra output files that we don't want to have lying around - # after it completes - with tempfile.TemporaryDirectory() as tmpdir: - - # Copy the science and reference image to the temp dir: - shutil.copy(reference_image, tmpdir) - shutil.copy(science_image, tmpdir) - - # Construct the command to run ImageMatch: - for match_threshold in (3.0, 5.0, 7.0, 10.0): - cmd = '"{python:s}" "{imgmatch:s}" -cfg "{config_file:s}" -snx {target_ra:.10f}d -sny {target_dec:.10f}d -p {kernel_radius:d} -o {order:d} -s {match:f} -scale {pixel_scale:} -mscale {mscale:} -m "{reference_image:s}" "{science_image:s}"'.format( - python=sys.executable, - imgmatch=imgmatch, - config_file=config_file, - reference_image=os.path.basename(reference_image), - science_image=os.path.basename(science_image), - target_ra=target['ra'], - target_dec=target['decl'], - match=match_threshold, - kernel_radius=kernel_radius, - pixel_scale=pixel_scale, - mscale=mscale, - order=1 - ) - logger.info("Executing command: %s", cmd) - - # Run the command in a subprocess: - cmd = shlex.split(cmd) - proc = subprocess.Popen(cmd, - cwd=tmpdir, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True) - stdout_data, stderr_data = proc.communicate() - returncode = proc.returncode - proc.kill() # Cleanup - Is this really needed? 
-
-            # Check the outputs from the subprocess:
-            logger.info("Return code: %d", returncode)
-            logger.info("STDOUT:\n%s", stdout_data.strip())
-            if stderr_data.strip() != '':
-                logger.error("STDERR:\n%s", stderr_data.strip())
-            if returncode < 0:
-                raise Exception("ImageMatch failed. Processed killed by OS with returncode %d." % returncode)
-            elif 'Failed object match... giving up.' in stdout_data:
-                #raise Exception("ImageMatch giving up matching objects")
-                continue
-            elif returncode > 0:
-                raise Exception("ImageMatch failed.")
-
-            # Load the resulting difference image into memory:
-            diffimg_name = re.sub(r'\.fits(\.gz|\.bz2)?$', r'diff.fits\1', os.path.basename(science_image))
-            diffimg_path = os.path.join(tmpdir, diffimg_name)
-            if not os.path.isfile(diffimg_path):
-                raise FileNotFoundError(diffimg_path)
-
-            break
-
-        else:
-            raise Exception("ImageMatch could not create difference image.")
-
-    with fits.open(diffimg_path, mode='readonly') as hdu:
-        diffimg = np.asarray(hdu[0].data)
-
-    return diffimg
+    """
+    Run ImageMatch on a datafile.
+
+    Parameters:
+        datafile (dict): Data file to run ImageMatch on.
+        target (:class:`astropy.table.Table`, optional): Target information.
+
+    .. codeauthor:: Rasmus Handberg
+    """
+
+    logger = logging.getLogger(__name__)
+
+    if datafile.get('template') is None:
+        raise ValueError("DATAFILE input does not specify a template to use.")
+
+    # Extract paths to science and reference images:
+    reference_image = os.path.join(datafile['archive_path'], datafile['template']['path'])
+    science_image = os.path.join(datafile['archive_path'], datafile['path'])
+
+    # If the target was not provided in the function call,
+    # use the API to get the target information:
+    if target is None:
+        catalog = api.get_catalog(datafile['targetid'], output='table')
+        target = catalog['target'][0]
+
+    # Find the path to where the ImageMatch program is installed.
+    # This is to avoid problems with it not being on the user's PATH
+    # and if the user is using some other version of the python executable.
+    # TODO: There must be a better way of doing this!
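+    # (A possibly simpler alternative, untested here: ``shutil.which('ImageMatch')``
+    # resolves an executable on both Windows and POSIX and could replace the
+    # platform-specific where/whereis calls below; shutil is already imported.)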
+ # imgmatch = os.path.join(get_setuptools_script_dir(), 'ImageMatch') + if os.name == "nt": + out = subprocess.check_output(["where", "ImageMatch"], universal_newlines=True) + imgmatch = out.strip() + else: + out = subprocess.check_output(["whereis", "ImageMatch"], universal_newlines=True) + out = re.match('ImageMatch: (.+)', out.strip()) + imgmatch = out.group(1) + + if not os.path.isfile(imgmatch): + raise FileNotFoundError("ImageMatch not found") + + # Find the ImageMatch config file to use based on the site of the observations: + __dir__ = os.path.dirname(os.path.abspath(__file__)) + if datafile['site'] in (1, 3, 4, 6): + config_file = os.path.join(__dir__, 'imagematch', 'imagematch_lcogt.cfg') + elif datafile['site'] == 2: + config_file = os.path.join(__dir__, 'imagematch', 'imagematch_hawki.cfg') + elif datafile['site'] == 5: + config_file = os.path.join(__dir__, 'imagematch', 'imagematch_alfosc.cfg') + else: + config_file = os.path.join(__dir__, 'imagematch', 'imagematch_default.cfg') + if not os.path.isfile(config_file): + raise FileNotFoundError(config_file) + + if pixel_scale is None: + if datafile['site'] in (1, 3, 4, 6): + # LCOGT provides the pixel scale directly in the header + pixel_scale = 'PIXSCALE' + else: + image = load_image(science_image) + pixel_area = proj_plane_pixel_area(image.wcs) + pixel_scale = np.sqrt(pixel_area) * 3600 # arcsec/pixel + logger.info("Calculated science image pixel scale: %f", pixel_scale) + + if datafile['template']['site'] in (1, 3, 4, 6): + # LCOGT provides the pixel scale directly in the header + mscale = 'PIXSCALE' + else: + template = load_image(reference_image) + template_pixel_area = proj_plane_pixel_area(template.wcs.celestial) + mscale = np.sqrt(template_pixel_area) * 3600 # arcsec/pixel + logger.info("Calculated template pixel scale: %f", mscale) + + # Scale kernel radius with FWHM: + if fwhm is None: + kernel_radius = 9 + else: + kernel_radius = max(9, int(np.ceil(1.5 * fwhm))) + if kernel_radius % 2 == 0: + kernel_radius += 1 + + # We will work in a temporary directory, since ImageMatch produces + # a lot of extra output files that we don't want to have lying around + # after it completes + with tempfile.TemporaryDirectory() as tmpdir: + + # Copy the science and reference image to the temp dir: + shutil.copy(reference_image, tmpdir) + shutil.copy(science_image, tmpdir) + + # Construct the command to run ImageMatch: + for match_threshold in (3.0, 5.0, 7.0, 10.0): + cmd = '"{python:s}" "{imgmatch:s}" -cfg "{config_file:s}" -snx {target_ra:.10f}d -sny {target_dec:.10f}d -p {kernel_radius:d} -o {order:d} -s {match:f} -scale {pixel_scale:} -mscale {mscale:} -m "{reference_image:s}" "{science_image:s}"'.format( + python=sys.executable, imgmatch=imgmatch, config_file=config_file, + reference_image=os.path.basename(reference_image), science_image=os.path.basename(science_image), + target_ra=target['ra'], target_dec=target['decl'], match=match_threshold, kernel_radius=kernel_radius, + pixel_scale=pixel_scale, mscale=mscale, order=1) + logger.info("Executing command: %s", cmd) + + # Run the command in a subprocess: + cmd = shlex.split(cmd) + proc = subprocess.Popen(cmd, cwd=tmpdir, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=True) + stdout_data, stderr_data = proc.communicate() + returncode = proc.returncode + proc.kill() # Cleanup - Is this really needed? 
+
+            # Check the outputs from the subprocess:
+            logger.info("Return code: %d", returncode)
+            logger.info("STDOUT:\n%s", stdout_data.strip())
+            if stderr_data.strip() != '':
+                logger.error("STDERR:\n%s", stderr_data.strip())
+            if returncode < 0:
+                raise Exception("ImageMatch failed. Process killed by OS with returncode %d." % returncode)
+            elif 'Failed object match... giving up.' in stdout_data:
+                # raise Exception("ImageMatch giving up matching objects")
+                continue
+            elif returncode > 0:
+                raise Exception("ImageMatch failed.")
+
+            # Load the resulting difference image into memory:
+            diffimg_name = re.sub(r'\.fits(\.gz|\.bz2)?$', r'diff.fits\1', os.path.basename(science_image))
+            diffimg_path = os.path.join(tmpdir, diffimg_name)
+            if not os.path.isfile(diffimg_path):
+                raise FileNotFoundError(diffimg_path)
+
+            break
+
+        else:
+            raise Exception("ImageMatch could not create difference image.")
+
+    with fits.open(diffimg_path, mode='readonly') as hdu:
+        diffimg = np.asarray(hdu[0].data)
+
+    return diffimg
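For reference, the re.sub() call above derives the difference-image filename from the science image while keeping any compression suffix. A quick sketch of the behaviour (filenames invented for the demo):

    import re

    for name in ('image.fits', 'image.fits.gz', 'image.fits.bz2'):
        print(re.sub(r'\.fits(\.gz|\.bz2)?$', r'diff.fits\1', name))
    # -> imagediff.fits, imagediff.fits.gz, imagediff.fits.bz2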
diff --git a/flows/tns.py b/flows/tns.py
index df3f190..004de12 100644
--- a/flows/tns.py
+++ b/flows/tns.py
@@ -20,268 +20,221 @@
 url_tns_api = 'https://www.wis-tns.org/api/get'
 url_tns_search = 'https://www.wis-tns.org/search'
 
-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 class TNSConfigError(RuntimeError):
-    pass
+    pass
 
-#--------------------------------------------------------------------------------------------------
-def _load_tns_config():
-    logger = logging.getLogger(__name__)
-
-    config = load_config()
-    api_key = config.get('TNS', 'api_key', fallback=None)
-    if api_key is None:
-        raise TNSConfigError("No TNS API-KEY has been defined in config")
-
-    tns_bot_id = config.getint('TNS', 'bot_id', fallback=93222)
-    tns_bot_name = config.get('TNS', 'bot_name', fallback='AUFLOWS_BOT')
-    tns_user_id = config.getint('TNS', 'user_id', fallback=None)
-    tns_user_name = config.get('TNS', 'user_name', fallback=None)
-
-    if tns_user_id and tns_user_name:
-        logger.debug('Using TNS credentials: user=%s', tns_user_name)
-        user_agent = 'tns_marker{"tns_id":' + str(tns_user_id) + ',"type":"user","name":"' + tns_user_name + '"}'
-    elif tns_bot_id and tns_bot_name:
-        logger.debug('Using TNS credentials: bot=%s', tns_bot_name)
-        user_agent = 'tns_marker{"tns_id":' + str(tns_bot_id) + ',"type":"bot","name":"' + tns_bot_name + '"}'
-    else:
-        raise TNSConfigError("No TNS bot_id or bot_name has been defined in config")
-
-    return {
-        'api-key': api_key,
-        'user-agent': user_agent
-    }
-
-#--------------------------------------------------------------------------------------------------
-def tns_search(coord=None, radius=3*u.arcsec, objname=None, internal_name=None):
-    """
-    Cone-search TNS for object near coordinate.
-
-    Parameters:
-        coord (:class:`astropy.coordinates.SkyCoord`): Central coordinate to search around.
-        radius (Angle, optional): Radius to search around ``coord``.
-        objname (str, optional): Search on object name.
-        internal_name (str, optional): Search on internal name.
-
-    Returns:
-        dict: Dictionary with TSN response.
-    """
-
-    # API key for Bot
-    tnsconf = _load_tns_config()
-
-    # change json_list to json format
-    json_file = {
-        'radius': radius.to('arcsec').value,
-        'units': 'arcsec',
-        'objname': objname,
-        'internal_name': internal_name
-    }
-    if coord:
-        json_file['ra'] = coord.icrs.ra.deg
-        json_file['dec'] = coord.icrs.dec.deg
-
-    # construct the list of (key,value) pairs
-    headers = {'user-agent': tnsconf['user-agent']}
-    search_data = [
-        ('api_key', (None, tnsconf['api-key'])),
-        ('data', (None, json.dumps(json_file)))
-    ]
-
-    # search obj using request module
-    res = requests.post(url_tns_api + '/search', files=search_data, headers=headers)
-    res.raise_for_status()
-    parsed = res.json()
-    data = parsed['data']
-
-    if 'reply' in data:
-        return data['reply']
-    return None
-
-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
+def _load_tns_config():
+    logger = logging.getLogger(__name__)
+
+    config = load_config()
+    api_key = config.get('TNS', 'api_key', fallback=None)
+    if api_key is None:
+        raise TNSConfigError("No TNS API-KEY has been defined in config")
+
+    tns_bot_id = config.getint('TNS', 'bot_id', fallback=93222)
+    tns_bot_name = config.get('TNS', 'bot_name', fallback='AUFLOWS_BOT')
+    tns_user_id = config.getint('TNS', 'user_id', fallback=None)
+    tns_user_name = config.get('TNS', 'user_name', fallback=None)
+
+    if tns_user_id and tns_user_name:
+        logger.debug('Using TNS credentials: user=%s', tns_user_name)
+        user_agent = 'tns_marker{"tns_id":' + str(tns_user_id) + ',"type":"user","name":"' + tns_user_name + '"}'
+    elif tns_bot_id and tns_bot_name:
+        logger.debug('Using TNS credentials: bot=%s', tns_bot_name)
+        user_agent = 'tns_marker{"tns_id":' + str(tns_bot_id) + ',"type":"bot","name":"' + tns_bot_name + '"}'
+    else:
+        raise TNSConfigError("No TNS bot_id or bot_name has been defined in config")
+
+    return {'api-key': api_key, 'user-agent': user_agent}
+
+
+# --------------------------------------------------------------------------------------------------
+def tns_search(coord=None, radius=3 * u.arcsec, objname=None, internal_name=None):
+    """
+    Cone-search TNS for object near coordinate.
+
+    Parameters:
+        coord (:class:`astropy.coordinates.SkyCoord`): Central coordinate to search around.
+        radius (Angle, optional): Radius to search around ``coord``.
+        objname (str, optional): Search on object name.
+        internal_name (str, optional): Search on internal name.
+
+    Returns:
+        dict: Dictionary with TNS response.
+    """
+
+    # API key for Bot
+    tnsconf = _load_tns_config()
+
+    # change json_list to json format
+    json_file = {'radius': radius.to('arcsec').value, 'units': 'arcsec', 'objname': objname,
+                 'internal_name': internal_name}
+    if coord:
+        json_file['ra'] = coord.icrs.ra.deg
+        json_file['dec'] = coord.icrs.dec.deg
+
+    # construct the list of (key,value) pairs
+    headers = {'user-agent': tnsconf['user-agent']}
+    search_data = [('api_key', (None, tnsconf['api-key'])), ('data', (None, json.dumps(json_file)))]
+
+    # search obj using the requests module
+    res = requests.post(url_tns_api + '/search', files=search_data, headers=headers)
+    res.raise_for_status()
+    parsed = res.json()
+    data = parsed['data']
+
+    if 'reply' in data:
+        return data['reply']
+    return None
+
+
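+# Hedged usage sketch of the cone-search above (coordinates are arbitrary; a valid
+# TNS api_key must already be configured for _load_tns_config to succeed):
+#
+#     from astropy.coordinates import SkyCoord
+#     pos = SkyCoord(ra=150.0 * u.deg, dec=2.2 * u.deg, frame='icrs')
+#     reply = tns_search(coord=pos, radius=5 * u.arcsec)
+#     if reply:
+#         print([entry.get('objname') for entry in reply])  # assumes 'objname' per entry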
+# --------------------------------------------------------------------------------------------------
 def tns_get_obj(name):
-    """
-    Search TNS for object by name.
-
-    Parameters:
-        name (str): Object name to search for.
-
-    Returns:
-        dict: Dictionary with TSN response.
-    """
-
-    # API key for Bot
-    tnsconf = _load_tns_config()
-
-    # construct the list of (key,value) pairs
-    headers = {'user-agent': tnsconf['user-agent']}
-    params = {'objname': name, 'photometry': '0', 'spectra': '0'}
-    get_data = [
-        ('api_key', (None, tnsconf['api-key'])),
-        ('data', (None, json.dumps(params)))
-    ]
-
-    # get obj using request module
-    res = requests.post(url_tns_api + '/object', files=get_data, headers=headers)
-    res.raise_for_status()
-    parsed = res.json()
-    data = parsed['data']
-
-    if 'reply' in data:
-        reply = data['reply']
-        if not reply:
-            return None
-        if 'objname' not in reply: # Bit of a cheat, but it is simple and works
-            return None
-
-        reply['internal_names'] = [name.strip() for name in reply['internal_names'].split(',') if name.strip()]
-        return reply
-    return None
-
-#--------------------------------------------------------------------------------------------------
+    """
+    Search TNS for object by name.
+
+    Parameters:
+        name (str): Object name to search for.
+
+    Returns:
+        dict: Dictionary with TNS response.
+    """
+
+    # API key for Bot
+    tnsconf = _load_tns_config()
+
+    # construct the list of (key,value) pairs
+    headers = {'user-agent': tnsconf['user-agent']}
+    params = {'objname': name, 'photometry': '0', 'spectra': '0'}
+    get_data = [('api_key', (None, tnsconf['api-key'])), ('data', (None, json.dumps(params)))]
+
+    # get obj using the requests module
+    res = requests.post(url_tns_api + '/object', files=get_data, headers=headers)
+    res.raise_for_status()
+    parsed = res.json()
+    data = parsed['data']
+
+    if 'reply' in data:
+        reply = data['reply']
+        if not reply:
+            return None
+        if 'objname' not in reply:  # Bit of a cheat, but it is simple and works
+            return None
+
+        reply['internal_names'] = [name.strip() for name in reply['internal_names'].split(',') if name.strip()]
+        return reply
+    return None
+
+
+# --------------------------------------------------------------------------------------------------
 def tns_getnames(months=None, date_begin=None, date_end=None, zmin=None, zmax=None, objtype=[3, 104]):
-    """
-    Get SN names from TNS.
-
-    Parameters:
-        months (int, optional): Only return objects reported within the last X months.
-        date_begin (date, optional): Discovery date begin.
-        date_end (date, optional): Discovery date end.
-        zmin (float, optional): Minimum redshift.
-        zmax (float, optional): Maximum redshift.
-        objtype (list, optional): Constraint object type.
-            Default is to query for
-            - 3: SN Ia
-            - 104: SN Ia-91T-like
-
-    Returns:
-        list: List of names fulfilling search criteria.
- """ - - logger = logging.getLogger(__name__) - - # Change formats of input to be ready for query: - if isinstance(date_begin, datetime.datetime): - date_begin = date_begin.date() - elif isinstance(date_begin, str): - date_begin = datetime.datetime.strptime(date_begin, '%Y-%m-%d').date() - - if isinstance(date_end, datetime.datetime): - date_end = date_end.date() - elif isinstance(date_end, str): - date_end = datetime.datetime.strptime(date_end, '%Y-%m-%d').date() - - if isinstance(objtype, (list, tuple)): - objtype = ','.join([str(o) for o in objtype]) - - # Do some sanity checks: - if date_end < date_begin: - raise ValueError("Dates are in the wrong order.") - - date_now = datetime.datetime.now(datetime.timezone.utc).date() - if months is not None and date_end is not None and date_end < date_now - datetime.timedelta(days=months*30): - logger.warning('Months limit restricts days_begin, consider increasing limit_months.') - - # API key for Bot - tnsconf = _load_tns_config() - - # Parameters for query: - params = { - 'discovered_period_value': months, # Reported Within The Last - 'discovered_period_units': 'months', - 'unclassified_at': 0, # Limit to unclasssified ATs - 'classified_sne': 1, # Limit to classified SNe - 'include_frb': 0, # Include FRBs - #'name': , - 'name_like': 0, - 'isTNS_AT': 'all', - 'public': 'all', - #'ra': - #'decl': - #'radius': - #'coords_unit': 'arcsec', - 'reporting_groupid[]': 'null', - 'groupid[]': 'null', - 'classifier_groupid[]': 'null', - 'objtype[]': objtype, - 'at_type[]': 'null', - 'date_start[date]': date_begin.isoformat(), - 'date_end[date]': date_end.isoformat(), - #'discovery_mag_min': - #'discovery_mag_max': - #'internal_name': - #'discoverer': - #'classifier': - #'spectra_count': - 'redshift_min': zmin, - 'redshift_max': zmax, - #'hostname': - #'ext_catid': - #'ra_range_min': - #'ra_range_max': - #'decl_range_min': - #'decl_range_max': - 'discovery_instrument[]': 'null', - 'classification_instrument[]': 'null', - 'associated_groups[]': 'null', - #'at_rep_remarks': - #'class_rep_remarks': - #'frb_repeat': 'all' - #'frb_repeater_of_objid': - 'frb_measured_redshift': 0, - #'frb_dm_range_min': - #'frb_dm_range_max': - #'frb_rm_range_min': - #'frb_rm_range_max': - #'frb_snr_range_min': - #'frb_snr_range_max': - #'frb_flux_range_min': - #'frb_flux_range_max': - 'num_page': 500, - 'display[redshift]': 0, - 'display[hostname]': 0, - 'display[host_redshift]': 0, - 'display[source_group_name]': 0, - 'display[classifying_source_group_name]': 0, - 'display[discovering_instrument_name]': 0, - 'display[classifing_instrument_name]': 0, - 'display[programs_name]': 0, - 'display[internal_name]': 0, - 'display[isTNS_AT]': 0, - 'display[public]': 0, - 'display[end_pop_period]': 0, - 'display[spectra_count]': 0, - 'display[discoverymag]': 0, - 'display[discmagfilter]': 0, - 'display[discoverydate]': 0, - 'display[discoverer]': 0, - 'display[remarks]': 0, - 'display[sources]': 0, - 'display[bibcode]': 0, - 'display[ext_catalogs]': 0, - 'format': 'csv' - } - - # Query TNS for names: - headers = {'user-agent': tnsconf['user-agent']} - con = requests.get(url_tns_search, params=params, headers=headers) - con.raise_for_status() - - # Parse the CSV table: - # Ensure that there is a newline in table string. 
-	# AstroPy uses this to distinguish file-paths from pure-string inputs:
-	text = str(con.text) + "\n"
-	tab = Table.read(text,
-		format='ascii.csv',
-		guess=False,
-		delimiter=',',
-		quotechar='"',
-		header_start=0,
-		data_start=1)
-
-	# Pull out the names only if they begin with "SN":
-	names_list = [name.replace(' ', '') for name in tab['Name'] if name.startswith('SN')]
-	names_list = sorted(names_list)
-
-	return names_list
+    """
+    Get SN names from TNS.
+
+    Parameters:
+        months (int, optional): Only return objects reported within the last X months.
+        date_begin (date, optional): Discovery date begin.
+        date_end (date, optional): Discovery date end.
+        zmin (float, optional): Minimum redshift.
+        zmax (float, optional): Maximum redshift.
+        objtype (list, optional): Constrain object type.
+            Default is to query for
+            - 3: SN Ia
+            - 104: SN Ia-91T-like
+
+    Returns:
+        list: List of names fulfilling search criteria.
+    """
+
+    logger = logging.getLogger(__name__)
+
+    # Change formats of input to be ready for query:
+    if isinstance(date_begin, datetime.datetime):
+        date_begin = date_begin.date()
+    elif isinstance(date_begin, str):
+        date_begin = datetime.datetime.strptime(date_begin, '%Y-%m-%d').date()
+
+    if isinstance(date_end, datetime.datetime):
+        date_end = date_end.date()
+    elif isinstance(date_end, str):
+        date_end = datetime.datetime.strptime(date_end, '%Y-%m-%d').date()
+
+    if isinstance(objtype, (list, tuple)):
+        objtype = ','.join([str(o) for o in objtype])
+
+    # Do some sanity checks:
+    if date_end < date_begin:
+        raise ValueError("Dates are in the wrong order.")
+
+    date_now = datetime.datetime.now(datetime.timezone.utc).date()
+    if months is not None and date_end is not None and date_end < date_now - datetime.timedelta(days=months * 30):
+        logger.warning('Months limit restricts date_begin, consider increasing months.')
+
+    # API key for Bot
+    tnsconf = _load_tns_config()
+
+    # Parameters for query:
+    params = {'discovered_period_value': months,  # Reported Within The Last
+              'discovered_period_units': 'months', 'unclassified_at': 0,  # Limit to unclassified ATs
+              'classified_sne': 1,  # Limit to classified SNe
+              'include_frb': 0,  # Include FRBs
+              # 'name': ,
+              'name_like': 0, 'isTNS_AT': 'all', 'public': 'all',  # 'ra':
+              # 'decl':
+              # 'radius':
+              # 'coords_unit': 'arcsec',
+              'reporting_groupid[]': 'null', 'groupid[]': 'null', 'classifier_groupid[]': 'null', 'objtype[]': objtype,
+              'at_type[]': 'null', 'date_start[date]': date_begin.isoformat(), 'date_end[date]': date_end.isoformat(),
+              # 'discovery_mag_min':
+              # 'discovery_mag_max':
+              # 'internal_name':
+              # 'discoverer':
+              # 'classifier':
+              # 'spectra_count':
+              'redshift_min': zmin, 'redshift_max': zmax,  # 'hostname':
+              # 'ext_catid':
+              # 'ra_range_min':
+              # 'ra_range_max':
+              # 'decl_range_min':
+              # 'decl_range_max':
+              'discovery_instrument[]': 'null', 'classification_instrument[]': 'null', 'associated_groups[]': 'null',
+              # 'at_rep_remarks':
+              # 'class_rep_remarks':
+              # 'frb_repeat': 'all'
+              # 'frb_repeater_of_objid':
+              'frb_measured_redshift': 0,  # 'frb_dm_range_min':
+              # 'frb_dm_range_max':
+              # 'frb_rm_range_min':
+              # 'frb_rm_range_max':
+              # 'frb_snr_range_min':
+              # 'frb_snr_range_max':
+              # 'frb_flux_range_min':
+              # 'frb_flux_range_max':
+              'num_page': 500, 'display[redshift]': 0, 'display[hostname]': 0, 'display[host_redshift]': 0,
+              'display[source_group_name]': 0, 'display[classifying_source_group_name]': 0,
+              'display[discovering_instrument_name]': 0, 'display[classifing_instrument_name]': 0,
+              'display[programs_name]': 
0, 'display[internal_name]': 0, 'display[isTNS_AT]': 0, 'display[public]': 0, + 'display[end_pop_period]': 0, 'display[spectra_count]': 0, 'display[discoverymag]': 0, + 'display[discmagfilter]': 0, 'display[discoverydate]': 0, 'display[discoverer]': 0, 'display[remarks]': 0, + 'display[sources]': 0, 'display[bibcode]': 0, 'display[ext_catalogs]': 0, 'format': 'csv'} + + # Query TNS for names: + headers = {'user-agent': tnsconf['user-agent']} + con = requests.get(url_tns_search, params=params, headers=headers) + con.raise_for_status() + + # Parse the CSV table: + # Ensure that there is a newline in table string. + # AstroPy uses this to distinguish file-paths from pure-string inputs: + text = str(con.text) + "\n" + tab = Table.read(text, format='ascii.csv', guess=False, delimiter=',', quotechar='"', header_start=0, data_start=1) + + # Pull out the names only if they begin with "SN": + names_list = [name.replace(' ', '') for name in tab['Name'] if name.startswith('SN')] + names_list = sorted(names_list) + + return names_list diff --git a/flows/utilities.py b/flows/utilities.py index e65054f..9e956ba 100644 --- a/flows/utilities.py +++ b/flows/utilities.py @@ -8,19 +8,20 @@ import hashlib -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def get_filehash(fname): - """Calculate SHA1-hash of file.""" - buf = 65536 - s = hashlib.sha1() - with open(fname, 'rb') as fid: - while True: - data = fid.read(buf) - if not data: - break - s.update(data) + """Calculate SHA1-hash of file.""" + buf = 65536 + s = hashlib.sha1() + with open(fname, 'rb') as fid: + while True: + data = fid.read(buf) + if not data: + break + s.update(data) - sha1sum = s.hexdigest().lower() - if len(sha1sum) != 40: - raise Exception("Invalid file hash") - return sha1sum + sha1sum = s.hexdigest().lower() + if len(sha1sum) != 40: + raise Exception("Invalid file hash") + return sha1sum diff --git a/flows/version.py b/flows/version.py index f746921..bd28768 100644 --- a/flows/version.py +++ b/flows/version.py @@ -30,128 +30,130 @@ # Find the "git" command to run depending on the OS: GIT_COMMAND = "git" if name == "nt": - def find_git_on_windows(): - """find the path to the git executable on windows""" - # first see if git is in the path - try: - check_output(["where", "/Q", "git"]) - # if this command succeeded, git is in the path - return "git" - # catch the exception thrown if git was not found - except CalledProcessError: - pass - # There are several locations git.exe may be hiding - possible_locations = [] - # look in program files for msysgit - if "PROGRAMFILES(X86)" in environ: - possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES(X86)"]) - if "PROGRAMFILES" in environ: - possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES"]) - # look for the github version of git - if "LOCALAPPDATA" in environ: - github_dir = "%s/GitHub" % environ["LOCALAPPDATA"] - if path.isdir(github_dir): - for subdir in listdir(github_dir): - if not subdir.startswith("PortableGit"): - continue - possible_locations.append("%s/%s/bin/git.exe" % (github_dir, subdir)) - for possible_location in possible_locations: - if path.isfile(possible_location): - return possible_location - # git was not found - return "git" - - GIT_COMMAND = find_git_on_windows() + def find_git_on_windows(): + """find the path to the git executable on windows""" + # first see if git is in the path + 
try:
+            check_output(["where", "/Q", "git"])
+            # if this command succeeded, git is in the path
+            return "git"
+        # catch the exception thrown if git was not found
+        except CalledProcessError:
+            pass
+        # There are several locations git.exe may be hiding
+        possible_locations = []
+        # look in program files for msysgit
+        if "PROGRAMFILES(X86)" in environ:
+            possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES(X86)"])
+        if "PROGRAMFILES" in environ:
+            possible_locations.append("%s/Git/cmd/git.exe" % environ["PROGRAMFILES"])
+        # look for the github version of git
+        if "LOCALAPPDATA" in environ:
+            github_dir = "%s/GitHub" % environ["LOCALAPPDATA"]
+            if path.isdir(github_dir):
+                for subdir in listdir(github_dir):
+                    if not subdir.startswith("PortableGit"):
+                        continue
+                    possible_locations.append("%s/%s/bin/git.exe" % (github_dir, subdir))
+        for possible_location in possible_locations:
+            if path.isfile(possible_location):
+                return possible_location
+        # git was not found
+        return "git"
+
+
+    GIT_COMMAND = find_git_on_windows()


 def call_git_describe(abbrev=7):
-	"""return the string output of git desribe"""
-	try:
-		with open(devnull, "w") as fnull:
-			arguments = [GIT_COMMAND, "describe", "--tags", "--abbrev=%d" % abbrev]
-			return check_output(arguments, cwd=CURRENT_DIRECTORY,
-				stderr=fnull).decode("ascii").strip()
-	except (OSError, CalledProcessError):
-		return None
+    """return the string output of git describe"""
+    try:
+        with open(devnull, "w") as fnull:
+            arguments = [GIT_COMMAND, "describe", "--tags", "--abbrev=%d" % abbrev]
+            return check_output(arguments, cwd=CURRENT_DIRECTORY, stderr=fnull).decode("ascii").strip()
+    except (OSError, CalledProcessError):
+        return None
+

 def call_git_getbranch():
-	try:
-		with open(devnull, "w") as fnull:
-			arguments = [GIT_COMMAND, "symbolic-ref", "--short", "HEAD"]
-			return check_output(arguments, cwd=CURRENT_DIRECTORY,
-				stderr=fnull).decode("ascii").strip()
-	except (OSError, CalledProcessError):
-		return None
+    try:
+        with open(devnull, "w") as fnull:
+            arguments = [GIT_COMMAND, "symbolic-ref", "--short", "HEAD"]
+            return check_output(arguments, cwd=CURRENT_DIRECTORY, stderr=fnull).decode("ascii").strip()
+    except (OSError, CalledProcessError):
+        return None
+

 def format_git_describe(git_str, pep440=False):
-	"""format the result of calling 'git describe' as a python version"""
-	if git_str is None:
-		return None
-	if "-" not in git_str: # currently at a tag
-		return git_str
-	else:
-		# formatted as version-N-githash
-		# want to convert to version.postN-githash
-		git_str = git_str.replace("-", ".post", 1)
-		if pep440: # does not allow git hash afterwards
-			return git_str.split("-")[0]
-		else:
-			return git_str.replace("-g", "+git")
+    """format the result of calling 'git describe' as a python version"""
+    if git_str is None:
+        return None
+    if "-" not in git_str:  # currently at a tag
+        return git_str
+    else:
+        # formatted as version-N-githash
+        # want to convert to version.postN-githash
+        git_str = git_str.replace("-", ".post", 1)
+        if pep440:  # does not allow git hash afterwards
+            return git_str.split("-")[0]
+        else:
+            return git_str.replace("-g", "+git")
+

 def read_release_version():
-	"""Read version information from VERSION file"""
-	try:
-		with open(VERSION_FILE, "r") as infile:
-			version = 
str(infile.read().strip())
-			if len(version) == 0:
-				version = None
-			return version
-	except IOError:
-		return None
+    """Read version information from VERSION file"""
+    try:
+        with open(VERSION_FILE, "r") as infile:
+            version = str(infile.read().strip())
+            if len(version) == 0:
+                version = None
+            return version
+    except IOError:
+        return None


 def update_release_version():
-	"""Update VERSION file"""
-	version = get_version(pep440=True)
-	with open(VERSION_FILE, "w") as outfile:
-		outfile.write(version)
+    """Update VERSION file"""
+    version = get_version(pep440=True)
+    with open(VERSION_FILE, "w") as outfile:
+        outfile.write(version)


 def get_version(pep440=False, include_branch=True):
-	"""
-	Tracks the version number.
+    """
+    Tracks the version number.

-	The file VERSION holds the version information. If this is not a git
-	repository, then it is reasonable to assume that the version is not
-	being incremented and the version returned will be the release version as
-	read from the file.
+    The file VERSION holds the version information. If this is not a git
+    repository, then it is reasonable to assume that the version is not
+    being incremented and the version returned will be the release version as
+    read from the file.

-	However, if the script is located within an active git repository,
-	git-describe is used to get the version information.
+    However, if the script is located within an active git repository,
+    git-describe is used to get the version information.

-	The file VERSION will need to be changed by manually. This should be done
-	before running git tag (set to the same as the version in the tag).
+    The file VERSION will need to be changed manually. This should be done
+    before running git tag (set to the same as the version in the tag).

-	Parameters:
-		pep440 (bool): When True, this function returns a version string suitable for
-			a release as defined by PEP 440. When False, the githash (if
-			available) will be appended to the version string.
+    Parameters:
+        pep440 (bool): When True, this function returns a version string suitable for
+            a release as defined by PEP 440. When False, the githash (if
+            available) will be appended to the version string.

-	Returns:
-		string: Version sting.
-	"""
+    Returns:
+        string: Version string.
+    """

-	git_version = format_git_describe(call_git_describe(), pep440=pep440)
-	if git_version is None: # not a git repository
-		return read_release_version()
+    git_version = format_git_describe(call_git_describe(), pep440=pep440)
+    if git_version is None:  # not a git repository
+        return read_release_version()

-	if include_branch:
-		git_branch = call_git_getbranch()
-		if git_branch is not None:
-			git_version = git_branch + '-' + git_version
+    if include_branch:
+        git_branch = call_git_getbranch()
+        if git_branch is not None:
+            git_version = git_branch + '-' + git_version

-	return git_version
+    return git_version


 if __name__ == "__main__":
-	print(get_version())
+    print(get_version())
diff --git a/flows/visibility.py b/flows/visibility.py
index 667a219..f18162c 100644
--- a/flows/visibility.py
+++ b/flows/visibility.py
@@ -18,116 +18,119 @@
 from astropy.visualization import quantity_support
 from . import api

-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 def visibility(target, siteid=None, date=None, output=None, overwrite=True):
-	"""
-	Create visibility plot.
-
-	Parameters:
-		target (str or int):
-		siteid (int): Identifier of site.
-		date (datetime or str, optional): Date for which to create visibility plot.
-			Default it to use the current date.
-		output (str, optional): Path to file or directory where to place the plot. 
- If not given, the plot will be created in memory, and can be shown on screen. - overwrite (bool, optional): Should existing file specified in ``output`` be overwritten? - Default is to overwrite an existing file. - - .. codeauthor:: Rasmus Handberg - """ - - logger = logging.getLogger(__name__) - - if date is None: - date = datetime.utcnow() - elif isinstance(date, str): - date = datetime.strptime(date, '%Y-%m-%d') - - tgt = api.get_target(target) - - # Coordinates of object: - obj = SkyCoord(ra=tgt['ra'], dec=tgt['decl'], unit='deg', frame='icrs') - - if siteid is None: - sites = api.get_all_sites() - else: - sites = [api.get_site(siteid)] - - plotpaths = [] - for site in sites: - # If we are saving plot to file, determine the path to save to - # and check that it doesn't already exist: - if output: - if os.path.isdir(output): - plotpath = os.path.join(output, "visibility_%s_%s_site%02d.png" % ( - tgt['target_name'], - date.strftime('%Y%m%d'), - site['siteid'])) - else: - plotpath = output - logger.debug("Will save visibility plot to '%s'", plotpath) - - # If we are not overwriting and - if not overwrite and os.path.exists(plotpath): - logger.info("File already exists: %s", plotpath) - continue - - # Observatory: - observatory = site['EarthLocation'] - utcoffset = (site['longitude']*u.deg/(360*u.deg)) * 24*u.hour - - # Create timestamps to calculate for: - midnight = Time(date.strftime('%Y-%m-%d') + ' 00:00:00', scale='utc') - utcoffset - delta_midnight = np.linspace(-12, 12, 1000)*u.hour - times = midnight + delta_midnight - - # AltAz frame: - AltAzFrame = AltAz(obstime=times, location=observatory) - - # Object: - altaz_obj = obj.transform_to(AltAzFrame) - - # The Sun and Moon: - altaz_sun = get_sun(times).transform_to(AltAzFrame) - altaz_moon = get_moon(times).transform_to(AltAzFrame) - - sundown_astro = (altaz_sun.alt < -6*u.deg) - if np.any(sundown_astro): - min_time = np.min(times[sundown_astro]) - 2*u.hour - max_time = np.max(times[sundown_astro]) + 2*u.hour - else: - min_time = times[0] - max_time = times[-1] - - quantity_support() - fig, ax = plt.subplots(1, 1, figsize=(15,9), squeeze=True) - plt.grid(ls=':', lw=0.5) - ax.plot(times.datetime, altaz_sun.alt, color='y', label='Sun') - ax.plot(times.datetime, altaz_moon.alt, color=[0.75]*3, ls='--', label='Moon') - objsc = ax.scatter(times.datetime, altaz_obj.alt, c=altaz_obj.az, label=tgt['target_name'], lw=0, s=8, cmap='twilight') - ax.fill_between(times.datetime, 0*u.deg, 90*u.deg, altaz_sun.alt < -0*u.deg, color='0.5', zorder=0) # , label='Night' - ax.fill_between(times.datetime, 0*u.deg, 90*u.deg, altaz_sun.alt < -18*u.deg, color='k', zorder=0) # , label='Astronomical Night' - - plt.colorbar(objsc, ax=ax, pad=0.01).set_label('Azimuth [deg]') - ax.legend(loc='upper left') - ax.minorticks_on() - ax.set_xlim(min_time.datetime, max_time.datetime) - ax.set_ylim(0*u.deg, 90*u.deg) - ax.set_title("%s - %s - %s" % (str(tgt['target_name']), date.strftime('%Y-%m-%d'), site['sitename']), fontsize=14) - plt.xlabel('Time [UTC]', fontsize=14) - plt.ylabel('Altitude [deg]', fontsize=16) - fig.autofmt_xdate() - - formatter = DateFormatter('%d/%m %H:%M') - ax.xaxis.set_major_formatter(formatter) - - if output: - fig.savefig(plotpath, bbox_inches='tight', transparent=True) - plt.close(fig) - plotpaths.append(plotpath) - - if output: - return plotpaths - - plt.show() - return ax + """ + Create visibility plot. + + Parameters: + target (str or int): + siteid (int): Identifier of site. 
+        date (datetime or str, optional): Date for which to create visibility plot.
+            Default is to use the current date.
+        output (str, optional): Path to file or directory where to place the plot.
+            If not given, the plot will be created in memory, and can be shown on screen.
+        overwrite (bool, optional): Should existing file specified in ``output`` be overwritten?
+            Default is to overwrite an existing file.
+
+    .. codeauthor:: Rasmus Handberg
+    """
+
+    logger = logging.getLogger(__name__)
+
+    if date is None:
+        date = datetime.utcnow()
+    elif isinstance(date, str):
+        date = datetime.strptime(date, '%Y-%m-%d')
+
+    tgt = api.get_target(target)
+
+    # Coordinates of object:
+    obj = SkyCoord(ra=tgt['ra'], dec=tgt['decl'], unit='deg', frame='icrs')
+
+    if siteid is None:
+        sites = api.get_all_sites()
+    else:
+        sites = [api.get_site(siteid)]
+
+    plotpaths = []
+    for site in sites:
+        # If we are saving plot to file, determine the path to save to
+        # and check that it doesn't already exist:
+        if output:
+            if os.path.isdir(output):
+                plotpath = os.path.join(output, "visibility_%s_%s_site%02d.png" % (
+                    tgt['target_name'], date.strftime('%Y%m%d'), site['siteid']))
+            else:
+                plotpath = output
+            logger.debug("Will save visibility plot to '%s'", plotpath)

+            # If we are not overwriting and the file already exists, skip it:
+            if not overwrite and os.path.exists(plotpath):
+                logger.info("File already exists: %s", plotpath)
+                continue
+
+        # Observatory:
+        observatory = site['EarthLocation']
+        utcoffset = (site['longitude'] * u.deg / (360 * u.deg)) * 24 * u.hour
+
+        # Create timestamps to calculate for:
+        midnight = Time(date.strftime('%Y-%m-%d') + ' 00:00:00', scale='utc') - utcoffset
+        delta_midnight = np.linspace(-12, 12, 1000) * u.hour
+        times = midnight + delta_midnight
+
+        # AltAz frame:
+        AltAzFrame = AltAz(obstime=times, location=observatory)
+
+        # Object:
+        altaz_obj = obj.transform_to(AltAzFrame)
+
+        # The Sun and Moon:
+        altaz_sun = get_sun(times).transform_to(AltAzFrame)
+        altaz_moon = get_moon(times).transform_to(AltAzFrame)
+
+        sundown_astro = (altaz_sun.alt < -6 * u.deg)
+        if np.any(sundown_astro):
+            min_time = np.min(times[sundown_astro]) - 2 * u.hour
+            max_time = np.max(times[sundown_astro]) + 2 * u.hour
+        else:
+            min_time = times[0]
+            max_time = times[-1]
+
+        quantity_support()
+        fig, ax = plt.subplots(1, 1, figsize=(15, 9), squeeze=True)
+        plt.grid(ls=':', lw=0.5)
+        ax.plot(times.datetime, altaz_sun.alt, color='y', label='Sun')
+        ax.plot(times.datetime, altaz_moon.alt, color=[0.75] * 3, ls='--', label='Moon')
+        objsc = ax.scatter(times.datetime, altaz_obj.alt, c=altaz_obj.az, label=tgt['target_name'], lw=0, s=8,
+                           cmap='twilight')
+        ax.fill_between(times.datetime, 0 * u.deg, 90 * u.deg, altaz_sun.alt < -0 * u.deg, color='0.5',
+                        zorder=0)  # , label='Night'
+        ax.fill_between(times.datetime, 0 * u.deg, 90 * u.deg, altaz_sun.alt < -18 * u.deg, color='k',
+                        zorder=0)  # , label='Astronomical Night'
+
+        plt.colorbar(objsc, ax=ax, pad=0.01).set_label('Azimuth [deg]')
+        ax.legend(loc='upper left')
+        ax.minorticks_on()
+        ax.set_xlim(min_time.datetime, max_time.datetime)
+        ax.set_ylim(0 * u.deg, 90 * u.deg)
+        ax.set_title("%s - %s - %s" % (str(tgt['target_name']), date.strftime('%Y-%m-%d'), site['sitename']),
+                     fontsize=14)
+        plt.xlabel('Time [UTC]', fontsize=14)
+        plt.ylabel('Altitude [deg]', fontsize=16)
+        fig.autofmt_xdate()
+
+        formatter = DateFormatter('%d/%m %H:%M')
+        ax.xaxis.set_major_formatter(formatter)
+
+        if output:
+            fig.savefig(plotpath, bbox_inches='tight', transparent=True)
+            plt.close(fig)
+            plotpaths.append(plotpath)
+
+    if output:
+        return plotpaths
+
+    plt.show()
+    return ax
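
A minimal usage sketch for the reformatted visibility() (the target name and date are hypothetical; a reachable Flows API and an existing output directory are assumed):

    from flows.visibility import visibility

    # Save visibility plots for one target at all sites for a given night:
    paths = visibility('2020xyz', date='2022-02-14', output='plots')
    print(paths)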
diff --git a/flows/zeropoint.py b/flows/zeropoint.py
index 39669da..486f7b2 100644
--- a/flows/zeropoint.py
+++ b/flows/zeropoint.py
@@ -15,54 +15,53 @@
 from scipy.special import erfcinv

-#Calculate sigma for sigma clipping using Chauvenet
+# Calculate sigma for sigma clipping using Chauvenet
 def sigma_from_Chauvenet(Nsamples):
-	'''Calculate sigma according to the Cheuvenet criterion'''
-	return erfcinv(1./(2*Nsamples)) * (2.)**(1/2)
+    '''Calculate sigma according to the Chauvenet criterion'''
+    return erfcinv(1. / (2 * Nsamples)) * (2.) ** (1 / 2)

-def bootstrap_outlier(x,y,yerr, n=500, model='None',fitter='None',
-		outlier='None', outlier_kwargs={'sigma':3}, summary='median', error='bootstrap',
-		parnames=['intercept'], return_vals=True):
-	'''x = catalog mag, y = instrumental mag, yerr = instrumental error
-	summary = function for summary statistic, np.nanmedian by default.
-	model = Linear1D
-	fitter = LinearLSQFitter
-	outlier = 'sigma_clip'
-	outlier_kwargs, default sigma = 3
-	return_vals = False will return dictionary
-	Performs bootstrap with replacement and returns model.
-	'''
-	summary = np.nanmedian if summary == 'median' else summary
-	error = np.nanstd if error == 'bootstrap' else error
+def bootstrap_outlier(x, y, yerr, n=500, model='None', fitter='None', outlier='None', outlier_kwargs={'sigma': 3},
+                      summary='median', error='bootstrap', parnames=['intercept'], return_vals=True):
+    '''x = catalog mag, y = instrumental mag, yerr = instrumental error
+    summary = function for summary statistic, np.nanmedian by default.
+    model = Linear1D
+    fitter = LinearLSQFitter
+    outlier = 'sigma_clip'
+    outlier_kwargs, default sigma = 3
+    return_vals = False will return dictionary
+    Performs bootstrap with replacement and returns model.
+    '''
+    summary = np.nanmedian if summary == 'median' else summary
+    error = np.nanstd if error == 'bootstrap' else error

-	#Create index for bootstrapping
-	ind = np.arange(len(x))
+    # Create index for bootstrapping
+    ind = np.arange(len(x))

-	#Bootstrap indexes with replacement using astropy
-	bootstraps = bootstrap(ind,bootnum=n)
-	bootstraps.sort() # sort increasing.
-	bootinds = bootstraps.astype(int)
+    # Bootstrap indexes with replacement using astropy
+    bootstraps = bootstrap(ind, bootnum=n)
+    bootstraps.sort()  # sort increasing.
+    bootinds = bootstraps.astype(int)

-	#Prepare fitter
-	fitter_instance = fitting.FittingWithOutlierRemoval(fitter(),outlier, **outlier_kwargs)
-	#Fit each bootstrap with model and fitter using outlier rejection at each step.
-	#Then obtain summary statistic for each parameter in parnames
-	pars = {}
-	out = {}
-	for parname in parnames:
-		pars[parname] = np.ones(len(bootinds), dtype=np.float64)
-	for i,bs in enumerate(bootinds):
-		#w = np.ones(len(x[bs]), dtype=np.float64) if yerr=='None' else (1.0/yerr[bs])**2
-		w = (1.0/yerr[bs])**2
-		best_fit, sigma_clipped = fitter_instance(model, x[bs], y[bs], weights=w)
-		#obtain parameters of interest
-		for parname in parnames:
-			pars[parname][i] = best_fit.parameters[np.array(best_fit.param_names) == parname][0]
-	if return_vals:
-		return [summary(pars[par]) for par in pars]
+    # Prepare fitter
+    fitter_instance = fitting.FittingWithOutlierRemoval(fitter(), outlier, **outlier_kwargs)
+    # Fit each bootstrap with model and fitter using outlier rejection at each step.
+    # Then obtain summary statistic for each parameter in parnames
+    pars = {}
+    out = {}
+    for parname in parnames:
+        pars[parname] = np.ones(len(bootinds), dtype=np.float64)
+    for i, bs in enumerate(bootinds):
+        # w = np.ones(len(x[bs]), dtype=np.float64) if yerr=='None' else (1.0/yerr[bs])**2
+        w = (1.0 / yerr[bs]) ** 2
+        best_fit, sigma_clipped = fitter_instance(model, x[bs], y[bs], weights=w)
+        # obtain parameters of interest
+        for parname in parnames:
+            pars[parname][i] = best_fit.parameters[np.array(best_fit.param_names) == parname][0]
+    if return_vals:
+        return [summary(pars[par]) for par in pars]

-	for parname in parnames:
-		out[parname] = summary(pars[parname])
-		out[parname+'_error'] = error(pars[parname])
-	return out
+    for parname in parnames:
+        out[parname] = summary(pars[parname])
+        out[parname + '_error'] = error(pars[parname])
+    return out
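
To illustrate how bootstrap_outlier() above is meant to be wired together, a sketch on fabricated data (the model, fitter and outlier choices follow the docstring's suggestions):

    import numpy as np
    from astropy.modeling import models, fitting
    from astropy.stats import sigma_clip
    from flows.zeropoint import bootstrap_outlier, sigma_from_Chauvenet

    # Fabricated catalog vs. instrumental magnitudes with a constant offset:
    x = np.linspace(14, 20, 50)
    y = x - 27.5 + np.random.normal(0, 0.05, x.size)
    yerr = np.full_like(x, 0.05)

    # Chauvenet's criterion sets the clipping threshold from the sample size:
    sigma = sigma_from_Chauvenet(len(x))

    # Median intercept (the magnitude offset) over bootstrap resamples:
    intercept, = bootstrap_outlier(x, y, yerr, n=100,
                                   model=models.Linear1D(slope=1.0, fixed={'slope': True}),
                                   fitter=fitting.LinearLSQFitter,
                                   outlier=sigma_clip, outlier_kwargs={'sigma': sigma},
                                   parnames=['intercept'])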
diff --git a/flows/ztf.py b/flows/ztf.py
index dc74d8e..c532ce2 100644
--- a/flows/ztf.py
+++ b/flows/ztf.py
@@ -17,133 +17,127 @@
 import requests
 from . import api

-#--------------------------------------------------------------------------------------------------
-def query_ztf_id(coo_centre, radius=3*u.arcsec, discovery_date=None):
-	"""
-	Query ALeRCE ZTF api to lookup ZTF identifier.
-
-	In case multiple identifiers are found within the search cone, the one
-	closest to the centre is returned.
-
-	Parameters:
-		coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone.
-		radius (Angle, optional): Search radius. Default 3 arcsec.
-		discovery_date (:class:`astropy.time.Time`, optional): Discovery date of target to
-			match against ZTF. The date is compared to the ZTF first timestamp and ZTF targets
-			are rejected if they are not within 15 days prior to the discovery date
-			and 90 days after.
-
-	Returns:
-		str: ZTF identifier.
-
-	.. codeauthor:: Emir Karamehmetoglu
-	.. 
codeauthor:: Rasmus Handberg - """ - - if isinstance(radius, (float, int)): - radius *= u.deg - - # Make json query for Alerce query API - query = { - 'ra': coo_centre.ra.deg, - 'dec': coo_centre.dec.deg, - 'radius': Angle(radius).arcsec, - 'page_size': 20, - 'count': True - } - - # Run http POST json query to alerce following their API - res = requests.get('https://api.alerce.online/ztf/v1/objects', params=query) - res.raise_for_status() - jsn = res.json() - - # If nothing was found, return None: - if jsn['total'] == 0: - return None - - # Start by removing anything marked as likely stellar-like source: - results = jsn['items'] - results = [itm for itm in results if not itm['stellar']] - if not results: - return None - - # Constrain on the discovery date if it is provided: - if discovery_date is not None: - # Extract the time of the first ZTF timestamp and compare it with - # the discovery time: - firstmjd = Time([itm['firstmjd'] for itm in results], format='mjd', scale='utc') - tdelta = firstmjd.utc.mjd - discovery_date.utc.mjd - - # Only keep results that are within the margins: - results = [itm for k, itm in enumerate(results) if -15 <= tdelta[k] <= 90] - if not results: - return None - - # Find target closest to the centre: - coords = SkyCoord( - ra=[itm['meanra'] for itm in results], - dec=[itm['meandec'] for itm in results], - unit='deg', - frame='icrs') - - indx = np.argmin(coords.separation(coo_centre)) - - return results[indx]['oid'] - -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- +def query_ztf_id(coo_centre, radius=3 * u.arcsec, discovery_date=None): + """ + Query ALeRCE ZTF api to lookup ZTF identifier. + + In case multiple identifiers are found within the search cone, the one + closest to the centre is returned. + + Parameters: + coo_centre (:class:`astropy.coordinates.SkyCoord`): Coordinates of centre of search cone. + radius (Angle, optional): Search radius. Default 3 arcsec. + discovery_date (:class:`astropy.time.Time`, optional): Discovery date of target to + match against ZTF. The date is compared to the ZTF first timestamp and ZTF targets + are rejected if they are not within 15 days prior to the discovery date + and 90 days after. + + Returns: + str: ZTF identifier. + + .. codeauthor:: Emir Karamehmetoglu + .. 
codeauthor:: Rasmus Handberg
+    """
+
+    if isinstance(radius, (float, int)):
+        radius *= u.deg
+
+    # Make json query for Alerce query API
+    query = {'ra': coo_centre.ra.deg, 'dec': coo_centre.dec.deg, 'radius': Angle(radius).arcsec, 'page_size': 20,
+             'count': True}
+
+    # Query ALeRCE over HTTP GET, following their API:
+    res = requests.get('https://api.alerce.online/ztf/v1/objects', params=query)
+    res.raise_for_status()
+    jsn = res.json()
+
+    # If nothing was found, return None:
+    if jsn['total'] == 0:
+        return None
+
+    # Start by removing anything marked as likely stellar-like source:
+    results = jsn['items']
+    results = [itm for itm in results if not itm['stellar']]
+    if not results:
+        return None
+
+    # Constrain on the discovery date if it is provided:
+    if discovery_date is not None:
+        # Extract the time of the first ZTF timestamp and compare it with
+        # the discovery time:
+        firstmjd = Time([itm['firstmjd'] for itm in results], format='mjd', scale='utc')
+        tdelta = firstmjd.utc.mjd - discovery_date.utc.mjd
+
+        # Only keep results that are within the margins:
+        results = [itm for k, itm in enumerate(results) if -15 <= tdelta[k] <= 90]
+        if not results:
+            return None
+
+    # Find target closest to the centre:
+    coords = SkyCoord(ra=[itm['meanra'] for itm in results], dec=[itm['meandec'] for itm in results], unit='deg',
+                      frame='icrs')
+
+    indx = np.argmin(coords.separation(coo_centre))
+
+    return results[indx]['oid']
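
A small usage sketch for query_ztf_id() (the coordinates and date are fabricated; network access to ALeRCE is assumed):

    import astropy.units as u
    from astropy.coordinates import SkyCoord
    from astropy.time import Time
    from flows.ztf import query_ztf_id

    coo = SkyCoord(ra=191.283, dec=-0.456, unit='deg', frame='icrs')  # fabricated position
    ztf_id = query_ztf_id(coo, radius=3 * u.arcsec, discovery_date=Time('2022-01-01', scale='utc'))
    print(ztf_id)  # a ZTF object id string, or None if nothing matches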
+
+
+# --------------------------------------------------------------------------------------------------
 def download_ztf_photometry(targetid):
-	"""
-	Download ZTF photometry from ALERCE API.
-
-	Parameters:
-		targetid (int): Target identifier.
+
+    Returns:
+        :class:`astropy.table.Table`: ZTF photometry table.
+
+    .. codeauthor:: Emir Karamehmetoglu
+    .. codeauthor:: Rasmus Handberg
+    """
+
+    # Get target info from Flows API:
+    tgt = api.get_target(targetid)
+    oid = tgt['ztf_id']
+    target_name = tgt['target_name']
+    if oid is None:
+        return None
+
+    # Query ALERCE for detections of object based on oid
+    res = requests.get(f'https://api.alerce.online/ztf/v1/objects/{oid:s}/detections')
+    res.raise_for_status()
+    jsn = res.json()
+
+    # Create Astropy table, cut out the needed columns
+    # and rename columns to something better for what we are doing:
+    tab = Table(data=jsn)
+    tab = tab[['fid', 'mjd', 'magpsf', 'sigmapsf']]
+    tab.rename_column('fid', 'photfilter')
+    tab.rename_column('mjd', 'time')
+    tab.rename_column('magpsf', 'mag')
+    tab.rename_column('sigmapsf', 'mag_err')
+
+    # Remove bad values of time and magnitude:
+    tab['time'] = np.asarray(tab['time'], dtype='float64')
+    tab['mag'] = np.asarray(tab['mag'], dtype='float64')
+    tab['mag_err'] = np.asarray(tab['mag_err'], dtype='float64')
+    indx = np.isfinite(tab['time']) & np.isfinite(tab['mag']) & np.isfinite(tab['mag_err'])
+    tab = tab[indx]
+
+    # Replace photometric filter numbers with keywords used in Flows:
+    photfilter_dict = {1: 'gp', 2: 'rp', 3: 'ip'}
+    tab['photfilter'] = [photfilter_dict[fid] for fid in tab['photfilter']]
+
+    # Sort the table on photfilter and time:
+    tab.sort(['photfilter', 'time'])
+
+    # Add meta information to table header:
+    tab.meta['target_name'] = target_name
+    tab.meta['targetid'] = targetid
+    tab.meta['ztf_id'] = oid
+    tab.meta['last_updated'] = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
+
+    return tab
diff --git a/notes/disk_covering_problem.py b/notes/disk_covering_problem.py
index 02e30dd..6824d49 100644
--- a/notes/disk_covering_problem.py
+++ b/notes/disk_covering_problem.py
@@ -12,34 +12,31 @@
 from matplotlib.patches import Circle
 import astropy.units as u

-#--------------------------------------------------------------------------------------------------
+# --------------------------------------------------------------------------------------------------
 if __name__ == '__main__':
-	coo_centre = SkyCoord(ra=0, dec=0, unit='deg', frame='icrs')
+    coo_centre = SkyCoord(ra=0, dec=0, unit='deg', frame='icrs')

-	radius = 24.0/60.0
+    radius = 24.0 / 60.0

-	#aframe = SkyOffsetFrame(origin=coo_centre)
-	#c = coo_centre.transform_to(aframe)
-	#print(c)
+    # aframe = SkyOffsetFrame(origin=coo_centre)
+    # c = coo_centre.transform_to(aframe)
+    # print(c)

-	fig, ax = plt.subplots()
-	ax.plot(coo_centre.ra.deg, coo_centre.dec.deg, 'rx')
+    fig, ax = plt.subplots()
+    ax.plot(coo_centre.ra.deg, coo_centre.dec.deg, 'rx')

-	ax.add_artist(Circle([coo_centre.ra.deg, coo_centre.dec.deg], radius=radius, ec='r', fc=None, fill=False))
-	ax.add_artist(Circle([coo_centre.ra.deg, coo_centre.dec.deg], radius=0.5*radius, ec='b', fc=None, fill=False))
+    ax.add_artist(Circle([coo_centre.ra.deg, coo_centre.dec.deg], radius=radius, ec='r', fc=None, fill=False))
+    ax.add_artist(Circle([coo_centre.ra.deg, coo_centre.dec.deg], radius=0.5 * radius, ec='b', fc=None, fill=False))

-	for n in range(6):
-		new = SkyCoord(
-			ra=coo_centre.ra.deg + 0.8 * radius * np.cos(n*60*np.pi/180),
-			dec=coo_centre.dec.deg + 0.8 * radius * np.sin(n*60*np.pi/180),
-			unit='deg', frame='icrs')
+    for n in range(6):
+        new = SkyCoord(ra=coo_centre.ra.deg + 0.8 * radius * np.cos(n * 60 * np.pi / 180),
+                       dec=coo_centre.dec.deg + 0.8 * radius * np.sin(n * 60 * np.pi / 180), unit='deg', 
frame='icrs') - ax.plot(new.ra.deg, new.dec.deg, 'bx') - ax.add_artist(Circle([new.ra.deg, new.dec.deg], radius=0.5*radius, ec='b', fc=None, fill=False)) + ax.plot(new.ra.deg, new.dec.deg, 'bx') + ax.add_artist(Circle([new.ra.deg, new.dec.deg], radius=0.5 * radius, ec='b', fc=None, fill=False)) - - plt.axis('equal') - #ax.set_xlim(coo_centre.ra.deg + radius * np.array([-2, 2])) - #ax.set_ylim(coo_centre.dec.deg +radius * np.array([-2, 2])) - plt.show() + plt.axis('equal') + # ax.set_xlim(coo_centre.ra.deg + radius * np.array([-2, 2])) + # ax.set_ylim(coo_centre.dec.deg +radius * np.array([-2, 2])) + plt.show() diff --git a/notes/fix_ztf_ids.py b/notes/fix_ztf_ids.py index e8df20c..d5268cb 100644 --- a/notes/fix_ztf_ids.py +++ b/notes/fix_ztf_ids.py @@ -5,29 +5,30 @@ import os.path from tqdm import tqdm from astropy.coordinates import SkyCoord + if os.path.abspath('..') not in sys.path: - sys.path.insert(0, os.path.abspath('..')) + sys.path.insert(0, os.path.abspath('..')) import flows -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - with flows.aadc_db.AADC_DB() as db: - for target in tqdm(flows.api.get_targets()): - if target['target_status'] == 'rejected': - continue + with flows.aadc_db.AADC_DB() as db: + for target in tqdm(flows.api.get_targets()): + if target['target_status'] == 'rejected': + continue - targetid = target['targetid'] - coord = SkyCoord(ra=target['ra'], dec=target['decl'], unit='deg', frame='icrs') - dd = target['discovery_date'] + targetid = target['targetid'] + coord = SkyCoord(ra=target['ra'], dec=target['decl'], unit='deg', frame='icrs') + dd = target['discovery_date'] - # Query for the ZTF id: - ztf_id = flows.ztf.query_ztf_id(coord, discovery_date=dd) + # Query for the ZTF id: + ztf_id = flows.ztf.query_ztf_id(coord, discovery_date=dd) - # If the ZTF id is not the same as we have currently, update it in the database: - if ztf_id != target['ztf_id']: - print(target) - print(ztf_id) - print("******* NEEDS UPDATE ******") + # If the ZTF id is not the same as we have currently, update it in the database: + if ztf_id != target['ztf_id']: + print(target) + print(ztf_id) + print("******* NEEDS UPDATE ******") - db.cursor.execute("UPDATE flows.targets SET ztf_id=%s WHERE targetid=%s;", (ztf_id, targetid)) - db.conn.commit() + db.cursor.execute("UPDATE flows.targets SET ztf_id=%s WHERE targetid=%s;", (ztf_id, targetid)) + db.conn.commit() diff --git a/notes/update_all_catalogs.py b/notes/update_all_catalogs.py index 7356441..3fcd5c0 100644 --- a/notes/update_all_catalogs.py +++ b/notes/update_all_catalogs.py @@ -6,47 +6,51 @@ import os.path import tqdm from astropy.coordinates import SkyCoord + if os.path.abspath('..') not in sys.path: - sys.path.insert(0, os.path.abspath('..')) + sys.path.insert(0, os.path.abspath('..')) import flows -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- class TqdmLoggingHandler(logging.Handler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def emit(self, record): - try: - msg = self.format(record) - tqdm.tqdm.write(msg) - self.flush() - except (KeyboardInterrupt, SystemExit): # pragma: no cover - raise - except: # noqa: E722, pragma: no cover - self.handleError(record) - 
-#-------------------------------------------------------------------------------------------------- + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.tqdm.write(msg) + self.flush() + except (KeyboardInterrupt, SystemExit): # pragma: no cover + raise + except: # noqa: E722, pragma: no cover + self.handleError(record) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = TqdmLoggingHandler() - console.setFormatter(formatter) - logger = logging.getLogger('flows') - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging.INFO) - - # Do it by status, just to prioritize things a bit: - for tgtstatus in ('target', 'candidate', 'rejected'): - targetids = sorted([tgt['targetid'] for tgt in flows.api.get_targets() if tgt['target_status'] == tgtstatus])[::-1] - - for targetid in tqdm.tqdm(targetids, desc=tgtstatus): - donefile = f"catalog_updates/{targetid:05d}.done" - if not os.path.exists(donefile): - try: - flows.catalogs.download_catalog(targetid, update_existing=True) - except: - logger.exception("targetid=%d", targetid) - else: - open(donefile, 'w').close() + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = TqdmLoggingHandler() + console.setFormatter(formatter) + logger = logging.getLogger('flows') + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging.INFO) + + # Do it by status, just to prioritize things a bit: + for tgtstatus in ('target', 'candidate', 'rejected'): + targetids = sorted([tgt['targetid'] for tgt in flows.api.get_targets() if tgt['target_status'] == tgtstatus])[ + ::-1] + + for targetid in tqdm.tqdm(targetids, desc=tgtstatus): + donefile = f"catalog_updates/{targetid:05d}.done" + if not os.path.exists(donefile): + try: + flows.catalogs.download_catalog(targetid, update_existing=True) + except: + logger.exception("targetid=%d", targetid) + else: + open(donefile, 'w').close() diff --git a/run_catalogs.py b/run_catalogs.py index c73a2d4..49fd13f 100644 --- a/run_catalogs.py +++ b/run_catalogs.py @@ -9,39 +9,39 @@ from flows import api, download_catalog if __name__ == '__main__': - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Run catalog.') - parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') - parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') - parser.add_argument('-t', '--target', type=str, help='Target to print catalog for.', nargs='?', default=None) - args = parser.parse_args() - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = logging.StreamHandler() - console.setFormatter(formatter) - logger = logging.getLogger('flows') - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging_level) - - for target in api.get_catalog_missing(): - logger.info("Downloading catalog for target=%s...", target) - download_catalog(target) - - if args.target is not None: - cat = api.get_catalog(args.target) - - print("Target:") - 
cat['target'].pprint_all() - print("\nReferences:") - cat['references'].pprint_all() - print("\nAvoid:") - cat['avoid'].pprint_all() + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Run catalog.') + parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') + parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') + parser.add_argument('-t', '--target', type=str, help='Target to print catalog for.', nargs='?', default=None) + args = parser.parse_args() + + # Set logging level: + logging_level = logging.INFO + if args.quiet: + logging_level = logging.WARNING + elif args.debug: + logging_level = logging.DEBUG + + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = logging.StreamHandler() + console.setFormatter(formatter) + logger = logging.getLogger('flows') + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging_level) + + for target in api.get_catalog_missing(): + logger.info("Downloading catalog for target=%s...", target) + download_catalog(target) + + if args.target is not None: + cat = api.get_catalog(args.target) + + print("Target:") + cat['target'].pprint_all() + print("\nReferences:") + cat['references'].pprint_all() + print("\nAvoid:") + cat['avoid'].pprint_all() diff --git a/run_download_ztf.py b/run_download_ztf.py index ea02f55..c996f46 100644 --- a/run_download_ztf.py +++ b/run_download_ztf.py @@ -13,107 +13,108 @@ from flows import ztf, api, load_config from flows.plots import plt -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def main(): - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Download ZTF photometry.') - parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') - parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') - parser.add_argument('-t', '--target', type=str, default=None, help='Target to download ZTF photometry for.') - parser.add_argument('-o', '--output', type=str, default=None, help='Directory to save output to.') - args = parser.parse_args() - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = logging.StreamHandler() - console.setFormatter(formatter) - logger = logging.getLogger(__name__) - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging_level) - - if args.output is None: - config = load_config() - output_dir = config.get('ztf', 'output_photometry', fallback='.') - else: - output_dir = args.output - logger.info("Saving output to '%s'", output_dir) - - # Check that output directory exists: - if not os.path.isdir(output_dir): - parser.error(f"Output directory does not exist: '{output_dir}'") # noqa: G004 - - # Use API to get list of targets to process: - if args.target is None: - targets = api.get_targets() - else: - targets = [api.get_target(args.target)] - - # Colors used for the different filters in plots: - # I know purple is in the wrong end of the scale, but not much I can do - colors = {'gp': 'tab:green', 'rp': 'tab:red', 'ip': 'tab:purple'} - - # Loop 
through targets: - for tgt in targets: - logger.debug("Target: %s", tgt) - target_name = tgt['target_name'] - - # Paths to the files to be updated: - ztf_lightcurve_path = os.path.join(output_dir, f'{target_name:s}-ztf.ecsv') - ztf_plot_path = os.path.join(output_dir, f'{target_name:s}-ztf.png') - - # If there is no ZTF id, there is no need to try: - # If an old file exists then delete it. - if tgt['ztf_id'] is None: - if os.path.isfile(ztf_lightcurve_path): - os.remove(ztf_lightcurve_path) - if os.path.isfile(ztf_plot_path): - os.remove(ztf_plot_path) - continue - - # Download ZTF photometry as Astropy Table: - tab = ztf.download_ztf_photometry(tgt['targetid']) - logger.debug("ZTF Photometry:\n%s", tab) - if tab is None or len(tab) == 0: - if os.path.isfile(ztf_lightcurve_path): - os.remove(ztf_lightcurve_path) - if os.path.isfile(ztf_plot_path): - os.remove(ztf_plot_path) - continue - - # Write table to file: - tab.write(ztf_lightcurve_path, format='ascii.ecsv', delimiter=',') - - # Find time of maxmimum and 14 days from that: - indx_min = np.argmin(tab['mag']) - maximum_mjd = tab['time'][indx_min] - fortnight_mjd = maximum_mjd + 14 - - # Get LC data out and save as CSV files - fig, ax = plt.subplots() - ax.axvline(maximum_mjd, ls='--', c='k', lw=0.5, label='Maximum') - ax.axvline(fortnight_mjd, ls='--', c='0.5', lw=0.5, label='+14 days') - for fid in np.unique(tab['photfilter']): - col = colors[fid] - band = tab[tab['photfilter'] == fid] - ax.errorbar(band['time'], band['mag'], band['mag_err'], - color=col, ls='-', lw=0.5, marker='.', label=fid) - - ax.invert_yaxis() - ax.set_title(target_name) - ax.set_xlabel('Time (MJD)') - ax.set_ylabel('Magnitude') - ax.legend() - fig.savefig(ztf_plot_path, format='png', bbox_inches='tight') - plt.close(fig) - -#-------------------------------------------------------------------------------------------------- + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Download ZTF photometry.') + parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') + parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') + parser.add_argument('-t', '--target', type=str, default=None, help='Target to download ZTF photometry for.') + parser.add_argument('-o', '--output', type=str, default=None, help='Directory to save output to.') + args = parser.parse_args() + + # Set logging level: + logging_level = logging.INFO + if args.quiet: + logging_level = logging.WARNING + elif args.debug: + logging_level = logging.DEBUG + + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = logging.StreamHandler() + console.setFormatter(formatter) + logger = logging.getLogger(__name__) + if not logger.hasHandlers(): + logger.addHandler(console) + logger.setLevel(logging_level) + + if args.output is None: + config = load_config() + output_dir = config.get('ztf', 'output_photometry', fallback='.') + else: + output_dir = args.output + logger.info("Saving output to '%s'", output_dir) + + # Check that output directory exists: + if not os.path.isdir(output_dir): + parser.error(f"Output directory does not exist: '{output_dir}'") # noqa: G004 + + # Use API to get list of targets to process: + if args.target is None: + targets = api.get_targets() + else: + targets = [api.get_target(args.target)] + + # Colors used for the different filters in plots: + # I know purple is in the wrong end of the scale, but not much I can do + colors = {'gp': 
'tab:green', 'rp': 'tab:red', 'ip': 'tab:purple'}
+
+    # Loop through targets:
+    for tgt in targets:
+        logger.debug("Target: %s", tgt)
+        target_name = tgt['target_name']
+
+        # Paths to the files to be updated:
+        ztf_lightcurve_path = os.path.join(output_dir, f'{target_name:s}-ztf.ecsv')
+        ztf_plot_path = os.path.join(output_dir, f'{target_name:s}-ztf.png')
+
+        # If there is no ZTF id, there is no need to try:
+        # If an old file exists then delete it.
+        if tgt['ztf_id'] is None:
+            if os.path.isfile(ztf_lightcurve_path):
+                os.remove(ztf_lightcurve_path)
+            if os.path.isfile(ztf_plot_path):
+                os.remove(ztf_plot_path)
+            continue
+
+        # Download ZTF photometry as Astropy Table:
+        tab = ztf.download_ztf_photometry(tgt['targetid'])
+        logger.debug("ZTF Photometry:\n%s", tab)
+        if tab is None or len(tab) == 0:
+            if os.path.isfile(ztf_lightcurve_path):
+                os.remove(ztf_lightcurve_path)
+            if os.path.isfile(ztf_plot_path):
+                os.remove(ztf_plot_path)
+            continue
+
+        # Write table to file:
+        tab.write(ztf_lightcurve_path, format='ascii.ecsv', delimiter=',')
+
+        # Find time of maximum and 14 days from that:
+        indx_min = np.argmin(tab['mag'])
+        maximum_mjd = tab['time'][indx_min]
+        fortnight_mjd = maximum_mjd + 14
+
+        # Plot the light curve:
+        fig, ax = plt.subplots()
+        ax.axvline(maximum_mjd, ls='--', c='k', lw=0.5, label='Maximum')
+        ax.axvline(fortnight_mjd, ls='--', c='0.5', lw=0.5, label='+14 days')
+        for fid in np.unique(tab['photfilter']):
+            col = colors[fid]
+            band = tab[tab['photfilter'] == fid]
+            ax.errorbar(band['time'], band['mag'], band['mag_err'], color=col, ls='-', lw=0.5, marker='.', label=fid)
+
+        ax.invert_yaxis()
+        ax.set_title(target_name)
+        ax.set_xlabel('Time (MJD)')
+        ax.set_ylabel('Magnitude')
+        ax.legend()
+        fig.savefig(ztf_plot_path, format='png', bbox_inches='tight')
+        plt.close(fig)
+
+
+# --------------------------------------------------------------------------------------------------
 if __name__ == '__main__':
-	main()
+    main()
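
The loop above reduces to calls like this sketch (the targetid and output name are hypothetical; Flows API access is assumed):

    from flows import ztf

    tab = ztf.download_ztf_photometry(1234)  # hypothetical Flows targetid
    if tab is not None and len(tab) > 0:
        tab.write('2020xyz-ztf.ecsv', format='ascii.ecsv', delimiter=',')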
diff --git a/run_ingest.py b/run_ingest.py
index 406f9ad..9e4cf9c 100644
--- a/run_ingest.py
+++ b/run_ingest.py
@@ -27,568 +27,564 @@
 from flows.load_image import load_image
 from flows.utilities import get_filehash

-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 def flows_get_archive_from_path(fname, archives_list=None):
-	"""
-	Translate full path into AADC archive identifier and relative path.
-
-	It is highly recommended to provide the list with that call
-	to this function since it will involve a query to the database
-	at every call.
-	"""
-
-	archive = None
-	relpath = None
-
-	# Get list of archives, if not provided with call:
-	if archives_list is None:
-		with AADC_DB() as db:
-			db.cursor.execute("SELECT archive,path FROM aadc.files_archives;")
-			archives_list = db.cursor.fetchall()
-
-	# Make sure folder is absolute path
-	folder = os.path.abspath(fname)
-
-	# Loop through the defined archives and find one that matches:
-	for opt in archives_list:
-		archive_path = opt['path']
-		if archive_path is not None and archive_path != '':
-			archive_path = archive_path.rstrip('/\\') + os.path.sep
-			if folder.startswith(archive_path):
-				archive = int(opt['archive'])
-				relpath = folder[len(archive_path):].replace('\\', '/')
-				break
-
-	# We did not find anything:
-	if archive is None:
-		raise RuntimeError("File not in registred archive")
-
-	return archive, relpath
-
-#--------------------------------------------------------------------------------------------------
+    """
+    Translate full path into AADC archive identifier and relative path.
+
+    It is highly recommended to provide the list with that call
+    to this function since it will involve a query to the database
+    at every call.
+    """
+
+    archive = None
+    relpath = None
+
+    # Get list of archives, if not provided with call:
+    if archives_list is None:
+        with AADC_DB() as db:
+            db.cursor.execute("SELECT archive,path FROM aadc.files_archives;")
+            archives_list = db.cursor.fetchall()
+
+    # Make sure folder is absolute path
+    folder = os.path.abspath(fname)
+
+    # Loop through the defined archives and find one that matches:
+    for opt in archives_list:
+        archive_path = opt['path']
+        if archive_path is not None and archive_path != '':
+            archive_path = archive_path.rstrip('/\\') + os.path.sep
+            if folder.startswith(archive_path):
+                archive = int(opt['archive'])
+                relpath = folder[len(archive_path):].replace('\\', '/')
+                break
+
+    # We did not find anything:
+    if archive is None:
+        raise RuntimeError("File not in registered archive")
+
+    return archive, relpath
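
A sketch of what flows_get_archive_from_path() returns (the archive row and path below are fabricated; passing a pre-fetched archives_list avoids one database query per file):

    archives_list = [{'archive': 1, 'path': '/flows/archive'}]  # fabricated row
    archive, relpath = flows_get_archive_from_path('/flows/archive/2020xyz/image.fits.gz', archives_list)
    print(archive, relpath)  # -> 1 2020xyz/image.fits.gz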
+
+
+# --------------------------------------------------------------------------------------------------
 def optipng(fpath):
-	os.system('optipng -preserve -quiet "%s"' % fpath)
+    os.system('optipng -preserve -quiet "%s"' % fpath)

-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 class CounterFilter(logging.Filter):
-	"""
-	A logging filter which counts the number of log records in each level.
-	"""
+    """
+    A logging filter which counts the number of log records in each level.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.counter = defaultdict(int)

-	def __init__(self, *args, **kwargs):
-		super().__init__(*args, **kwargs)
-		self.counter = defaultdict(int)

-	def filter(self, record): # noqa: A003
-		self.counter[record.levelname] += 1
-		return True

+    def filter(self, record):  # noqa: A003
+        self.counter[record.levelname] += 1
+        return True


-#--------------------------------------------------------------------------------------------------
+# --------------------------------------------------------------------------------------------------
 def create_plot(filepath, target_coord=None, target_position=None):
+    output_fpath = os.path.abspath(re.sub(r'\.fits(\.gz)?$', '', filepath) + '.png')

-	output_fpath = os.path.abspath(re.sub(r'\.fits(\.gz)?$', '', filepath) + '.png')

-	img = load_image(filepath, target_coord=target_coord)
+    img = load_image(filepath, target_coord=target_coord)

+    fig = plt.figure(figsize=(12, 12))
+    ax = fig.add_subplot(111)
+    plot_image(img.clean, ax=ax, scale='linear', percentile=[5, 99], cbar='right')
+    if target_position is not None:
+        ax.scatter(target_position[0], target_position[1], marker='+', s=20, c='r', label='Target')
+    fig.savefig(output_fpath, bbox_inches='tight')
+    plt.close(fig)

-	fig = plt.figure(figsize=(12,12))
-	ax = fig.add_subplot(111)
-	plot_image(img.clean, ax=ax, scale='linear', percentile=[5, 99], cbar='right')
-	if target_position is not None:
-		ax.scatter(target_position[0], target_position[1], marker='+', s=20, c='r', label='Target')
-	fig.savefig(output_fpath, bbox_inches='tight')
-	plt.close(fig)

-	optipng(output_fpath)
+    optipng(output_fpath)


-#--------------------------------------------------------------------------------------------------
+# --------------------------------------------------------------------------------------------------
 def ingest_from_inbox():
-
-	rootdir_inbox = '/flows/inbox'
-	rootdir = '/flows/archive'
-
-	logger = logging.getLogger(__name__)
-
-	# Check that root directories are available:
-	if not os.path.isdir(rootdir_inbox):
-		raise FileNotFoundError("INBOX does not exists")
-	if not os.path.isdir(rootdir):
-		raise FileNotFoundError("ARCHIVE does not exists")
-
-	with AADC_DB() as db:
-		# Get list of archives:
-		db.cursor.execute("SELECT archive,path FROM aadc.files_archives;")
-		archives_list = db.cursor.fetchall()
-
-		# Get list of all available filters:
-		db.cursor.execute("SELECT photfilter FROM flows.photfilters;")
-		all_filters = set([row['photfilter'] for row in db.cursor.fetchall()])
-
-		for inputtype in ('science', 'templates', 'subtracted', 'replace'): #
-			for fpath in glob.iglob(os.path.join(rootdir_inbox, '*', inputtype, '*')):
-				logger.info("="*72)
-				logger.info(fpath)
-
-				# Find the uploadlog corresponding to this file:
-				db.cursor.execute("SELECT logid FROM flows.uploadlog WHERE uploadpath=%s;", [os.path.relpath(fpath, rootdir_inbox)])
-				row = db.cursor.fetchone()
-				if row is not None:
-					uploadlogid = row['logid']
-				else:
-					uploadlogid = None
-				logger.info("Uploadlog ID: %s", uploadlogid)
-
-				# Only accept FITS file, or already compressed FITS files:
-				if not fpath.endswith('.fits') and not fpath.endswith('.fits.gz'):
-					if uploadlogid:
-						db.cursor.execute("UPDATE flows.uploadlog SET status='Invalid file type' WHERE logid=%s;", [uploadlogid])
-						db.conn.commit()
-					logger.error("Invalid file type: %s", os.path.relpath(fpath, rootdir_inbox))
-					continue
-
-				# Get the name of the directory:
-				# Not pretty, but it works... 
- target_dirname = fpath[len(rootdir_inbox)+1:] - target_dirname = target_dirname.split(os.path.sep)[0] - - # Convert directory name to target - db.cursor.execute("SELECT targetid,target_name,ra,decl FROM flows.targets WHERE target_name=%s;", [target_dirname]) - row = db.cursor.fetchone() - if row is None: - logger.error('Could not find target: %s', target_dirname) - continue - targetid = row['targetid'] - targetname = row['target_name'] - target_radec = [[row['ra'], row['decl']]] - target_coord = coords.SkyCoord( - ra=row['ra'], - dec=row['decl'], - unit='deg', - frame='icrs') - - if not fpath.endswith('.gz'): - # Gzip the FITS file: - with open(fpath, 'rb') as f_in: - with gzip.open(fpath + '.gz', 'wb') as f_out: - f_out.writelines(f_in) - - # We should now have a Gzip file instead: - if os.path.isfile(fpath) and os.path.isfile(fpath + '.gz'): - # Update the log of this file: - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=%s WHERE logid=%s;", [os.path.relpath(fpath+'.gz', rootdir_inbox), uploadlogid]) - db.conn.commit() - - os.remove(fpath) - fpath += '.gz' - else: - raise RuntimeError("Gzip file was not created correctly") - - version = 1 - if inputtype == 'science': - newpath = os.path.join(rootdir, targetname, os.path.basename(fpath)) - datatype = 1 - elif inputtype == 'templates': - newpath = os.path.join(rootdir, targetname, inputtype, os.path.basename(fpath)) - datatype = 3 - elif inputtype == 'subtracted': - newpath = os.path.join(rootdir, targetname, inputtype, os.path.basename(fpath)) - datatype = 4 - - original_fname = os.path.basename(fpath).replace('diff.fits', '.fits') - db.cursor.execute("SELECT fileid FROM flows.files WHERE targetid=%s AND datatype=1 AND path LIKE %s;", [targetid, '%/' + original_fname]) - subtracted_original_fileid = db.cursor.fetchone() - if subtracted_original_fileid is None: - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='original science image not found' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - logger.error("ORIGINAL SCIENCE IMAGE COULD NOT BE FOUND: %s", os.path.basename(fpath)) - continue - else: - subtracted_original_fileid = subtracted_original_fileid[0] - - elif inputtype == 'replace': - bname = os.path.basename(fpath) - m = re.match(r'^(\d+)_v(\d+)\.fits(\.gz)?$', bname) - if m: - replaceid = int(m.group(1)) - version = int(m.group(2)) - - db.cursor.execute("SELECT datatype,path,version FROM flows.files WHERE fileid=%s;", [replaceid]) - row = db.cursor.fetchone() - if row is None: - logger.error("Unknown fileid to be replaced: %s", bname) - continue - datatype = row['datatype'] - subdir = {1: '', 4: 'subtracted'}[datatype] - - if version != row['version'] + 1: - logger.error("Mismatch in versions: old=%d, new=%d", row['version'], version) - continue - - newfilename = re.sub(r'(_v\d+)?\.fits(\.gz)?$', r'_v{version:d}.fits\2'.format(version=version), os.path.basename(row['path'])) - newpath = os.path.join(rootdir, targetname, subdir, newfilename) - - if datatype == 4: - db.cursor.execute("SELECT associd FROM flows.files_cross_assoc INNER JOIN flows.files ON files.fileid=files_cross_assoc.associd WHERE files_cross_assoc.fileid=%s AND datatype=1;", [replaceid]) - subtracted_original_fileid = db.cursor.fetchone() - if subtracted_original_fileid is None: - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='original science image not found' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - logger.error("ORIGINAL SCIENCE IMAGE COULD NOT BE FOUND: %s", 
os.path.basename(fpath)) - continue - else: - subtracted_original_fileid = subtracted_original_fileid[0] - else: - logger.error("Invalid replace file name: %s", bname) - continue - - else: - raise RuntimeError("Not understood, Captain") - - logger.info(newpath) - - if os.path.exists(newpath): - logger.error("Already exists") - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Already exists: file name' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - continue - - archive, relpath = flows_get_archive_from_path(newpath, archives_list) - - db.cursor.execute("SELECT fileid FROM flows.files WHERE archive=%s AND path=%s;", [archive, relpath]) - if db.cursor.fetchone() is not None: - logger.error("ALREADY DONE") - continue - - # Calculate filehash of the file being stored: - filehash = get_filehash(fpath) - - # Check that the file does not already exist: - db.cursor.execute("SELECT fileid FROM flows.files WHERE filehash=%s;", [filehash]) - if db.cursor.fetchone() is not None: - logger.error("ALREADY DONE: Filehash") - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Already exists: filehash' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - continue - - # Try to load the image using the same function as the pipeline would: - try: - img = load_image(fpath, target_coord=target_coord) - except Exception as e: # pragma: no cover - logger.exception("Could not load FITS image") - if uploadlogid: - errmsg = str(e) if hasattr(e, 'message') else str(e.message) - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", ['Load Image Error: ' + errmsg, uploadlogid]) - db.conn.commit() - continue - - # Use the WCS in the file to calculate the pixel-positon of the target: - try: - target_pixels = img.wcs.all_world2pix(target_radec, 0).flatten() - except: # noqa: E722, pragma: no cover - logger.exception("Could not find target position using the WCS.") - if uploadlogid: - errmsg = "Could not find target position using the WCS." - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) - db.conn.commit() - continue - - # Check that the position of the target actually falls within - # the pixels of the image: - if target_pixels[0] < -0.5 or target_pixels[1] < -0.5 \ - or target_pixels[0] > img.shape[1]-0.5 or target_pixels[1] > img.shape[0]-0.5: - logger.error("Target position does not fall within image. Check the WCS.") - if uploadlogid: - errmsg = "Target position does not fall within image. Check the WCS." 
- db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) - db.conn.commit() - continue - - # Check that the site was found: - if img.site is None or img.site['siteid'] is None: - logger.error("Unknown SITE") - if uploadlogid: - errmsg = "Unknown site" - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) - db.conn.commit() - continue - - # Check that the extracted photometric filter is valid: - if img.photfilter not in all_filters: - logger.error("Unknown PHOTFILTER: %s", img.photfilter) - if uploadlogid: - errmsg = "Unknown PHOTFILTER: '" + str(img.photfilter) + "'" - db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) - db.conn.commit() - continue - - # Do a deep check to ensure that there is not already another file with the same - # properties (target, datatype, site, filter) taken at the same time: - # TODO: Look at the actual overlap with the database, instead of just overlap - # with the central value. This way of doing it is more forgiving. - obstime = img.obstime.utc.mjd - if inputtype != 'replace': - db.cursor.execute("SELECT fileid FROM flows.files WHERE targetid=%s AND datatype=%s AND site=%s AND photfilter=%s AND obstime BETWEEN %s AND %s;", [ - targetid, - datatype, - img.site['siteid'], - img.photfilter, - obstime - 0.5 * img.exptime/86400, - obstime + 0.5 * img.exptime/86400, - ]) - if db.cursor.fetchone() is not None: - logger.error("ALREADY DONE: Deep check") - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Already exists: deep check' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - continue - - try: - # Copy the file to its new home: - os.makedirs(os.path.dirname(newpath), exist_ok=True) - shutil.copy(fpath, newpath) - - # Set file and directory permissions: - # TODO: Can this not be handled in a more elegant way? 
-					os.chmod(os.path.dirname(newpath), 0o2750)
-					os.chmod(newpath, 0o0440)
-
-					filesize = os.path.getsize(fpath)
-
-					if not fpath.endswith('-e00.fits'):
-						create_plot(newpath, target_coord=target_coord, target_position=target_pixels)
-
-					db.cursor.execute("INSERT INTO flows.files (archive,path,targetid,datatype,site,filesize,filehash,obstime,photfilter,exptime,version,available) VALUES (%(archive)s,%(relpath)s,%(targetid)s,%(datatype)s,%(site)s,%(filesize)s,%(filehash)s,%(obstime)s,%(photfilter)s,%(exptime)s,%(version)s,1) RETURNING fileid;", {
-						'archive': archive,
-						'relpath': relpath,
-						'targetid': targetid,
-						'datatype': datatype,
-						'site': img.site['siteid'],
-						'filesize': filesize,
-						'filehash': filehash,
-						'obstime': obstime,
-						'photfilter': img.photfilter,
-						'exptime': img.exptime,
-						'version': version
-					})
-					fileid = db.cursor.fetchone()[0]
-
-					if datatype == 4:
-						db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", [fileid, subtracted_original_fileid])
-
-					if inputtype == 'replace':
-						db.cursor.execute("UPDATE flows.files SET newest_version=FALSE WHERE fileid=%s;", [replaceid])
-
-					if uploadlogid:
-						db.cursor.execute("UPDATE flows.uploadlog SET fileid=%s,status='ok' WHERE logid=%s;", [fileid, uploadlogid])
-
-					db.conn.commit()
-				except: # noqa: E722, pragma: no cover
-					db.conn.rollback()
-					if os.path.exists(newpath):
-						os.remove(newpath)
-					raise
-				else:
-					logger.info("DELETE THE ORIGINAL FILE")
-					if os.path.isfile(newpath):
-						os.remove(fpath)
-					if uploadlogid:
-						db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL WHERE logid=%s;", [uploadlogid])
-						db.conn.commit()
-
-#--------------------------------------------------------------------------------------------------
+    rootdir_inbox = '/flows/inbox'
+    rootdir = '/flows/archive'
+
+    logger = logging.getLogger(__name__)
+
+    # Check that root directories are available:
+    if not os.path.isdir(rootdir_inbox):
+        raise FileNotFoundError("INBOX does not exist")
+    if not os.path.isdir(rootdir):
+        raise FileNotFoundError("ARCHIVE does not exist")
+
+    with AADC_DB() as db:
+        # Get list of archives:
+        db.cursor.execute("SELECT archive,path FROM aadc.files_archives;")
+        archives_list = db.cursor.fetchall()
+
+        # Get list of all available filters:
+        db.cursor.execute("SELECT photfilter FROM flows.photfilters;")
+        all_filters = set([row['photfilter'] for row in db.cursor.fetchall()])
+
+        for inputtype in ('science', 'templates', 'subtracted', 'replace'):
+            for fpath in glob.iglob(os.path.join(rootdir_inbox, '*', inputtype, '*')):
+                logger.info("=" * 72)
+                logger.info(fpath)
+
+                # Find the uploadlog corresponding to this file:
+                db.cursor.execute("SELECT logid FROM flows.uploadlog WHERE uploadpath=%s;",
+                                  [os.path.relpath(fpath, rootdir_inbox)])
+                row = db.cursor.fetchone()
+                if row is not None:
+                    uploadlogid = row['logid']
+                else:
+                    uploadlogid = None
+                logger.info("Uploadlog ID: %s", uploadlogid)
+
+                # Only accept FITS files, or already-compressed FITS files:
+                if not fpath.endswith('.fits') and not fpath.endswith('.fits.gz'):
+                    if uploadlogid:
+                        db.cursor.execute("UPDATE flows.uploadlog SET status='Invalid file type' WHERE logid=%s;",
+                                          [uploadlogid])
+                        db.conn.commit()
+                    logger.error("Invalid file type: %s", os.path.relpath(fpath, rootdir_inbox))
+                    continue
+
+                # Get the name of the directory:
+                # Not pretty, but it works...
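The two "not pretty" lines that follow simply take the first path component below the inbox root; an equivalent way of writing the same thing, shown here only as a sketch with a hypothetical upload path:

import os

rootdir_inbox = '/flows/inbox'
fpath = '/flows/inbox/2020xyz/science/image.fits'  # hypothetical upload
target_dirname = os.path.relpath(fpath, rootdir_inbox).split(os.path.sep)[0]
print(target_dirname)  # -> '2020xyz'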
+ target_dirname = fpath[len(rootdir_inbox) + 1:] + target_dirname = target_dirname.split(os.path.sep)[0] + + # Convert directory name to target + db.cursor.execute("SELECT targetid,target_name,ra,decl FROM flows.targets WHERE target_name=%s;", + [target_dirname]) + row = db.cursor.fetchone() + if row is None: + logger.error('Could not find target: %s', target_dirname) + continue + targetid = row['targetid'] + targetname = row['target_name'] + target_radec = [[row['ra'], row['decl']]] + target_coord = coords.SkyCoord(ra=row['ra'], dec=row['decl'], unit='deg', frame='icrs') + + if not fpath.endswith('.gz'): + # Gzip the FITS file: + with open(fpath, 'rb') as f_in: + with gzip.open(fpath + '.gz', 'wb') as f_out: + f_out.writelines(f_in) + + # We should now have a Gzip file instead: + if os.path.isfile(fpath) and os.path.isfile(fpath + '.gz'): + # Update the log of this file: + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=%s WHERE logid=%s;", + [os.path.relpath(fpath + '.gz', rootdir_inbox), uploadlogid]) + db.conn.commit() + + os.remove(fpath) + fpath += '.gz' + else: + raise RuntimeError("Gzip file was not created correctly") + + version = 1 + if inputtype == 'science': + newpath = os.path.join(rootdir, targetname, os.path.basename(fpath)) + datatype = 1 + elif inputtype == 'templates': + newpath = os.path.join(rootdir, targetname, inputtype, os.path.basename(fpath)) + datatype = 3 + elif inputtype == 'subtracted': + newpath = os.path.join(rootdir, targetname, inputtype, os.path.basename(fpath)) + datatype = 4 + + original_fname = os.path.basename(fpath).replace('diff.fits', '.fits') + db.cursor.execute( + "SELECT fileid FROM flows.files WHERE targetid=%s AND datatype=1 AND path LIKE %s;", + [targetid, '%/' + original_fname]) + subtracted_original_fileid = db.cursor.fetchone() + if subtracted_original_fileid is None: + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='original science image not found' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + logger.error("ORIGINAL SCIENCE IMAGE COULD NOT BE FOUND: %s", os.path.basename(fpath)) + continue + else: + subtracted_original_fileid = subtracted_original_fileid[0] + + elif inputtype == 'replace': + bname = os.path.basename(fpath) + m = re.match(r'^(\d+)_v(\d+)\.fits(\.gz)?$', bname) + if m: + replaceid = int(m.group(1)) + version = int(m.group(2)) + + db.cursor.execute("SELECT datatype,path,version FROM flows.files WHERE fileid=%s;", [replaceid]) + row = db.cursor.fetchone() + if row is None: + logger.error("Unknown fileid to be replaced: %s", bname) + continue + datatype = row['datatype'] + subdir = {1: '', 4: 'subtracted'}[datatype] + + if version != row['version'] + 1: + logger.error("Mismatch in versions: old=%d, new=%d", row['version'], version) + continue + + newfilename = re.sub(r'(_v\d+)?\.fits(\.gz)?$', r'_v{version:d}.fits\2'.format(version=version), + os.path.basename(row['path'])) + newpath = os.path.join(rootdir, targetname, subdir, newfilename) + + if datatype == 4: + db.cursor.execute( + "SELECT associd FROM flows.files_cross_assoc INNER JOIN flows.files ON files.fileid=files_cross_assoc.associd WHERE files_cross_assoc.fileid=%s AND datatype=1;", + [replaceid]) + subtracted_original_fileid = db.cursor.fetchone() + if subtracted_original_fileid is None: + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='original science image not found' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + logger.error("ORIGINAL SCIENCE IMAGE COULD NOT BE 
FOUND: %s", os.path.basename(fpath)) + continue + else: + subtracted_original_fileid = subtracted_original_fileid[0] + else: + logger.error("Invalid replace file name: %s", bname) + continue + + else: + raise RuntimeError("Not understood, Captain") + + logger.info(newpath) + + if os.path.exists(newpath): + logger.error("Already exists") + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='Already exists: file name' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + continue + + archive, relpath = flows_get_archive_from_path(newpath, archives_list) + + db.cursor.execute("SELECT fileid FROM flows.files WHERE archive=%s AND path=%s;", [archive, relpath]) + if db.cursor.fetchone() is not None: + logger.error("ALREADY DONE") + continue + + # Calculate filehash of the file being stored: + filehash = get_filehash(fpath) + + # Check that the file does not already exist: + db.cursor.execute("SELECT fileid FROM flows.files WHERE filehash=%s;", [filehash]) + if db.cursor.fetchone() is not None: + logger.error("ALREADY DONE: Filehash") + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='Already exists: filehash' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + continue + + # Try to load the image using the same function as the pipeline would: + try: + img = load_image(fpath, target_coord=target_coord) + except Exception as e: # pragma: no cover + logger.exception("Could not load FITS image") + if uploadlogid: + errmsg = str(e) if hasattr(e, 'message') else str(e.message) + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", + ['Load Image Error: ' + errmsg, uploadlogid]) + db.conn.commit() + continue + + # Use the WCS in the file to calculate the pixel-positon of the target: + try: + target_pixels = img.wcs.all_world2pix(target_radec, 0).flatten() + except: # noqa: E722, pragma: no cover + logger.exception("Could not find target position using the WCS.") + if uploadlogid: + errmsg = "Could not find target position using the WCS." + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) + db.conn.commit() + continue + + # Check that the position of the target actually falls within + # the pixels of the image: + if target_pixels[0] < -0.5 or target_pixels[1] < -0.5 or target_pixels[0] > img.shape[1] - 0.5 or \ + target_pixels[1] > img.shape[0] - 0.5: + logger.error("Target position does not fall within image. Check the WCS.") + if uploadlogid: + errmsg = "Target position does not fall within image. Check the WCS." 
+ db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) + db.conn.commit() + continue + + # Check that the site was found: + if img.site is None or img.site['siteid'] is None: + logger.error("Unknown SITE") + if uploadlogid: + errmsg = "Unknown site" + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) + db.conn.commit() + continue + + # Check that the extracted photometric filter is valid: + if img.photfilter not in all_filters: + logger.error("Unknown PHOTFILTER: %s", img.photfilter) + if uploadlogid: + errmsg = "Unknown PHOTFILTER: '" + str(img.photfilter) + "'" + db.cursor.execute("UPDATE flows.uploadlog SET status=%s WHERE logid=%s;", [errmsg, uploadlogid]) + db.conn.commit() + continue + + # Do a deep check to ensure that there is not already another file with the same + # properties (target, datatype, site, filter) taken at the same time: + # TODO: Look at the actual overlap with the database, instead of just overlap + # with the central value. This way of doing it is more forgiving. + obstime = img.obstime.utc.mjd + if inputtype != 'replace': + db.cursor.execute( + "SELECT fileid FROM flows.files WHERE targetid=%s AND datatype=%s AND site=%s AND photfilter=%s AND obstime BETWEEN %s AND %s;", + [targetid, datatype, img.site['siteid'], img.photfilter, obstime - 0.5 * img.exptime / 86400, + obstime + 0.5 * img.exptime / 86400, ]) + if db.cursor.fetchone() is not None: + logger.error("ALREADY DONE: Deep check") + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='Already exists: deep check' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + continue + + try: + # Copy the file to its new home: + os.makedirs(os.path.dirname(newpath), exist_ok=True) + shutil.copy(fpath, newpath) + + # Set file and directory permissions: + # TODO: Can this not be handled in a more elegant way? 
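An aside on the octal modes used in the chmod calls just below: 0o2750 is a setgid, group-readable directory and 0o0440 a read-only file, which presumably is what keeps the archive effectively immutable. A small sketch decoding the bits with the standard stat module:

import stat

# Directory mode 0o2750: rwx for owner, r-x for group, nothing for others,
# plus the setgid bit so new files inherit the directory's group:
print(stat.filemode(stat.S_IFDIR | 0o2750))  # drwxr-s---

# File mode 0o0440: read-only for owner and group:
print(stat.filemode(stat.S_IFREG | 0o0440))  # -r--r-----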
+ os.chmod(os.path.dirname(newpath), 0o2750) + os.chmod(newpath, 0o0440) + + filesize = os.path.getsize(fpath) + + if not fpath.endswith('-e00.fits'): + create_plot(newpath, target_coord=target_coord, target_position=target_pixels) + + db.cursor.execute( + "INSERT INTO flows.files (archive,path,targetid,datatype,site,filesize,filehash,obstime,photfilter,exptime,version,available) VALUES (%(archive)s,%(relpath)s,%(targetid)s,%(datatype)s,%(site)s,%(filesize)s,%(filehash)s,%(obstime)s,%(photfilter)s,%(exptime)s,%(version)s,1) RETURNING fileid;", + {'archive': archive, 'relpath': relpath, 'targetid': targetid, 'datatype': datatype, + 'site': img.site['siteid'], 'filesize': filesize, 'filehash': filehash, 'obstime': obstime, + 'photfilter': img.photfilter, 'exptime': img.exptime, 'version': version}) + fileid = db.cursor.fetchone()[0] + + if datatype == 4: + db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", + [fileid, subtracted_original_fileid]) + + if inputtype == 'replace': + db.cursor.execute("UPDATE flows.files SET newest_version=FALSE WHERE fileid=%s;", [replaceid]) + + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET fileid=%s,status='ok' WHERE logid=%s;", + [fileid, uploadlogid]) + + db.conn.commit() + except: # noqa: E722, pragma: no cover + db.conn.rollback() + if os.path.exists(newpath): + os.remove(newpath) + raise + else: + logger.info("DELETE THE ORIGINAL FILE") + if os.path.isfile(newpath): + os.remove(fpath) + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL WHERE logid=%s;", [uploadlogid]) + db.conn.commit() + + +# -------------------------------------------------------------------------------------------------- def ingest_photometry_from_inbox(): - - rootdir_inbox = '/flows/inbox' - rootdir_archive = '/flows/archive_photometry' - - logger = logging.getLogger(__name__) - - # Check that root directories are available: - if not os.path.isdir(rootdir_inbox): - raise FileNotFoundError("INBOX does not exists") - if not os.path.isdir(rootdir_archive): - raise FileNotFoundError("ARCHIVE does not exists") - - with AADC_DB() as db: - # Get list of archives: - db.cursor.execute("SELECT archive,path FROM aadc.files_archives;") - archives_list = db.cursor.fetchall() - - for fpath in glob.iglob(os.path.join(rootdir_inbox, '*', 'photometry', '*')): - logger.info("="*72) - logger.info(fpath) - - # Find the uploadlog corresponding to this file: - db.cursor.execute("SELECT logid FROM flows.uploadlog WHERE uploadpath=%s;", [os.path.relpath(fpath, rootdir_inbox)]) - row = db.cursor.fetchone() - if row is not None: - uploadlogid = row['logid'] - else: - uploadlogid = None - - # Only accept FITS file, or already compressed FITS files: - if not fpath.endswith('.zip'): - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Invalid file type' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - logger.error("Invalid file type: %s", os.path.relpath(fpath, rootdir_inbox)) - continue - - # Get the name of the directory: - # Not pretty, but it works... 
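One note on the version bookkeeping used on both sides of this hunk: a re-ingested photometry file gets MAX(version)+1 from the files table, or 1 if no earlier version exists. A sketch of that rule, with the SQL result stubbed out as a plain value:

def next_version(latest_version):
    # latest_version stands in for the result of the
    # "SELECT MAX(files.version) ..." query; None means no earlier ingest.
    return 1 if latest_version is None else latest_version + 1

assert next_version(None) == 1
assert next_version(3) == 4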
- target_dirname = fpath[len(rootdir_inbox)+1:] - target_dirname = target_dirname.split(os.path.sep)[0] - - # Convert directory name to target - db.cursor.execute("SELECT targetid,target_name FROM flows.targets WHERE target_name=%s;", [target_dirname]) - row = db.cursor.fetchone() - if row is None: - logger.error('Could not find target: %s', target_dirname) - continue - targetid = row['targetid'] - targetname = row['target_name'] - - newpath = None - try: - with tempfile.TemporaryDirectory() as tmpdir: - # - tmpphotfile = os.path.join(tmpdir, 'photometry.ecsv') - - # Extract the ZIP file: - with ZipFile(fpath, mode='r') as myzip: - for member in myzip.infolist(): - # Remove any directory structure from the zip file: - if member.filename.endswith('/'): # member.is_dir() - continue - member.filename = os.path.basename(member.filename) - - # Due to security considerations, we only allow specific files - # to be extracted: - if member.filename == 'photometry.ecsv': - myzip.extract(member, path=tmpdir) - elif member.filename.endswith('.png') or member.filename.endswith('.log'): - myzip.extract(member, path=tmpdir) - - # Check that the photometry ECSV file at least exists: - if not os.path.isfile(tmpphotfile): - raise FileNotFoundError("Photometry is not found") - - # Load photometry table: - tab = Table.read(tmpphotfile, format='ascii.ecsv') - fileid_img = int(tab.meta['fileid']) - targetid_table = int(tab.meta['targetid']) - - assert targetid_table == targetid - - # Find out which version number to assign to file: - db.cursor.execute("SELECT MAX(files.version) AS latest_version FROM flows.files_cross_assoc fca INNER JOIN flows.files ON fca.fileid=files.fileid WHERE fca.associd=%s AND files.datatype=2;", [fileid_img,]) - latest_version = db.cursor.fetchone() - if latest_version[0] is None: - new_version = 1 - else: - new_version = latest_version[0] + 1 - - # Create a new path and filename that is slightly more descriptive: - newpath = os.path.join( - rootdir_archive, - targetname, - f'{fileid_img:05d}', - f'v{new_version:02d}', - f'photometry-{targetname:s}-{fileid_img:05d}-v{new_version:02d}.ecsv' - ) - logger.info(newpath) - - if os.path.exists(newpath): - logger.error("Already exists") - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET status='Already exists: file name' WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - continue - - archive, relpath = flows_get_archive_from_path(newpath, archives_list) - - db.cursor.execute("SELECT fileid FROM flows.files WHERE archive=%s AND path=%s;", [archive, relpath]) - if db.cursor.fetchone() is not None: - logger.error("ALREADY DONE") - continue - - db.cursor.execute("SELECT * FROM flows.files WHERE fileid=%s;", [fileid_img]) - row = db.cursor.fetchone() - site = row['site'] - - assert targetid == row['targetid'] - assert tab.meta['photfilter'] == row['photfilter'] - - # Optimize all the PNG files in the temp directory: - for f in glob.iglob(os.path.join(tmpdir, '*.png')): - optipng(f) - - # Copy the full directory to its new home: - shutil.copytree(tmpdir, os.path.dirname(newpath)) - os.rename(os.path.join(os.path.dirname(newpath), 'photometry.ecsv'), newpath) - - # Get information about file: - filesize = os.path.getsize(newpath) - filehash = get_filehash(newpath) - - db.cursor.execute("INSERT INTO flows.files (archive,path,targetid,datatype,site,filesize,filehash,obstime,photfilter,version,available) VALUES 
(%(archive)s,%(relpath)s,%(targetid)s,%(datatype)s,%(site)s,%(filesize)s,%(filehash)s,%(obstime)s,%(photfilter)s,%(version)s,1) RETURNING fileid;", {
-						'archive': archive,
-						'relpath': relpath,
-						'targetid': targetid,
-						'datatype': 2,
-						'site': site,
-						'filesize': filesize,
-						'filehash': filehash,
-						'obstime': tab.meta['obstime-bmjd'],
-						'photfilter': tab.meta['photfilter'],
-						'version': new_version
-					})
-					fileid = db.cursor.fetchone()[0]
-
-					# Add dependencies:
-					db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", [fileid, fileid_img])
-					if tab.meta['template'] is not None:
-						db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", [fileid, tab.meta['template']])
-					if tab.meta['diffimg'] is not None:
-						db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", [fileid, tab.meta['diffimg']])
-
-					indx_raw = (tab['starid'] == 0)
-					indx_sub = (tab['starid'] == -1)
-					indx_ref = (tab['starid'] > 0)
-
-					frd = float(np.nanmax(tab[indx_ref]['mag']))
-					if not np.isfinite(frd):
-						frd = None
-
-					phot_summary = {
-						'fileid_img': fileid_img,
-						'fileid_phot': fileid,
-						'fileid_template': tab.meta['template'],
-						'fileid_diffimg': tab.meta['diffimg'],
-						'targetid': targetid,
-						'obstime': tab.meta['obstime-bmjd'],
-						'photfilter': tab.meta['photfilter'],
-						'mag_raw': float(tab[indx_raw]['mag']),
-						'mag_raw_error': float(tab[indx_raw]['mag_error']),
-						'mag_sub': None if not any(indx_sub) else float(tab[indx_sub]['mag']),
-						'mag_sub_error': None if not any(indx_sub) else float(tab[indx_sub]['mag_error']),
-						'zeropoint': float(tab.meta['zp']),
-						'zeropoint_error': float(tab.meta['zp_error']),
-						'zeropoint_diff': float(tab.meta['zp_diff']),
-						'fwhm': float(tab.meta['fwhm'].value),
-						'seeing': float(tab.meta['seeing'].value),
-						'references_detected': int(np.sum(indx_ref)),
-						'used_for_epsf': int(np.sum(tab['used_for_epsf'])),
-						'faintest_reference_detected': frd,
-						'pipeline_version': tab.meta['version'],
-						'latest_version': new_version
-					}
-
-					db.cursor.execute("""INSERT INTO flows.photometry_details (
+    rootdir_inbox = '/flows/inbox'
+    rootdir_archive = '/flows/archive_photometry'
+
+    logger = logging.getLogger(__name__)
+
+    # Check that root directories are available:
+    if not os.path.isdir(rootdir_inbox):
+        raise FileNotFoundError("INBOX does not exist")
+    if not os.path.isdir(rootdir_archive):
+        raise FileNotFoundError("ARCHIVE does not exist")
+
+    with AADC_DB() as db:
+        # Get list of archives:
+        db.cursor.execute("SELECT archive,path FROM aadc.files_archives;")
+        archives_list = db.cursor.fetchall()
+
+        for fpath in glob.iglob(os.path.join(rootdir_inbox, '*', 'photometry', '*')):
+            logger.info("=" * 72)
+            logger.info(fpath)
+
+            # Find the uploadlog corresponding to this file:
+            db.cursor.execute("SELECT logid FROM flows.uploadlog WHERE uploadpath=%s;",
+                              [os.path.relpath(fpath, rootdir_inbox)])
+            row = db.cursor.fetchone()
+            if row is not None:
+                uploadlogid = row['logid']
+            else:
+                uploadlogid = None
+
+            # Only accept ZIP files:
+            if not fpath.endswith('.zip'):
+                if uploadlogid:
+                    db.cursor.execute("UPDATE flows.uploadlog SET status='Invalid file type' WHERE logid=%s;",
+                                      [uploadlogid])
+                    db.conn.commit()
+                logger.error("Invalid file type: %s", os.path.relpath(fpath, rootdir_inbox))
+                continue
+
+            # Get the name of the directory:
+            # Not pretty, but it works...
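The extraction loop further below flattens any directory structure inside the upload and, for security, only whitelists the photometry table plus PNG and log files. The same pattern as a stand-alone sketch (function and file names hypothetical):

import os
from zipfile import ZipFile

def extract_whitelisted(zip_path, dest):
    with ZipFile(zip_path, mode='r') as myzip:
        for member in myzip.infolist():
            if member.filename.endswith('/'):  # skip directory entries
                continue
            # Strip any leading directories, then extract only expected files:
            member.filename = os.path.basename(member.filename)
            if member.filename == 'photometry.ecsv' \
                    or member.filename.endswith(('.png', '.log')):
                myzip.extract(member, path=dest)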
+ target_dirname = fpath[len(rootdir_inbox) + 1:] + target_dirname = target_dirname.split(os.path.sep)[0] + + # Convert directory name to target + db.cursor.execute("SELECT targetid,target_name FROM flows.targets WHERE target_name=%s;", [target_dirname]) + row = db.cursor.fetchone() + if row is None: + logger.error('Could not find target: %s', target_dirname) + continue + targetid = row['targetid'] + targetname = row['target_name'] + + newpath = None + try: + with tempfile.TemporaryDirectory() as tmpdir: + # + tmpphotfile = os.path.join(tmpdir, 'photometry.ecsv') + + # Extract the ZIP file: + with ZipFile(fpath, mode='r') as myzip: + for member in myzip.infolist(): + # Remove any directory structure from the zip file: + if member.filename.endswith('/'): # member.is_dir() + continue + member.filename = os.path.basename(member.filename) + + # Due to security considerations, we only allow specific files + # to be extracted: + if member.filename == 'photometry.ecsv': + myzip.extract(member, path=tmpdir) + elif member.filename.endswith('.png') or member.filename.endswith('.log'): + myzip.extract(member, path=tmpdir) + + # Check that the photometry ECSV file at least exists: + if not os.path.isfile(tmpphotfile): + raise FileNotFoundError("Photometry is not found") + + # Load photometry table: + tab = Table.read(tmpphotfile, format='ascii.ecsv') + fileid_img = int(tab.meta['fileid']) + targetid_table = int(tab.meta['targetid']) + + assert targetid_table == targetid + + # Find out which version number to assign to file: + db.cursor.execute( + "SELECT MAX(files.version) AS latest_version FROM flows.files_cross_assoc fca INNER JOIN flows.files ON fca.fileid=files.fileid WHERE fca.associd=%s AND files.datatype=2;", + [fileid_img, ]) + latest_version = db.cursor.fetchone() + if latest_version[0] is None: + new_version = 1 + else: + new_version = latest_version[0] + 1 + + # Create a new path and filename that is slightly more descriptive: + newpath = os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}', f'v{new_version:02d}', + f'photometry-{targetname:s}-{fileid_img:05d}-v{new_version:02d}.ecsv') + logger.info(newpath) + + if os.path.exists(newpath): + logger.error("Already exists") + if uploadlogid: + db.cursor.execute( + "UPDATE flows.uploadlog SET status='Already exists: file name' WHERE logid=%s;", + [uploadlogid]) + db.conn.commit() + continue + + archive, relpath = flows_get_archive_from_path(newpath, archives_list) + + db.cursor.execute("SELECT fileid FROM flows.files WHERE archive=%s AND path=%s;", + [archive, relpath]) + if db.cursor.fetchone() is not None: + logger.error("ALREADY DONE") + continue + + db.cursor.execute("SELECT * FROM flows.files WHERE fileid=%s;", [fileid_img]) + row = db.cursor.fetchone() + site = row['site'] + + assert targetid == row['targetid'] + assert tab.meta['photfilter'] == row['photfilter'] + + # Optimize all the PNG files in the temp directory: + for f in glob.iglob(os.path.join(tmpdir, '*.png')): + optipng(f) + + # Copy the full directory to its new home: + shutil.copytree(tmpdir, os.path.dirname(newpath)) + os.rename(os.path.join(os.path.dirname(newpath), 'photometry.ecsv'), newpath) + + # Get information about file: + filesize = os.path.getsize(newpath) + filehash = get_filehash(newpath) + + db.cursor.execute( + "INSERT INTO flows.files (archive,path,targetid,datatype,site,filesize,filehash,obstime,photfilter,version,available) VALUES 
(%(archive)s,%(relpath)s,%(targetid)s,%(datatype)s,%(site)s,%(filesize)s,%(filehash)s,%(obstime)s,%(photfilter)s,%(version)s,1) RETURNING fileid;", + {'archive': archive, 'relpath': relpath, 'targetid': targetid, 'datatype': 2, 'site': site, + 'filesize': filesize, 'filehash': filehash, 'obstime': tab.meta['obstime-bmjd'], + 'photfilter': tab.meta['photfilter'], 'version': new_version}) + fileid = db.cursor.fetchone()[0] + + # Add dependencies: + db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", + [fileid, fileid_img]) + if tab.meta['template'] is not None: + db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", + [fileid, tab.meta['template']]) + if tab.meta['diffimg'] is not None: + db.cursor.execute("INSERT INTO flows.files_cross_assoc (fileid,associd) VALUES (%s,%s);", + [fileid, tab.meta['diffimg']]) + + indx_raw = (tab['starid'] == 0) + indx_sub = (tab['starid'] == -1) + indx_ref = (tab['starid'] > 0) + + frd = float(np.nanmax(tab[indx_ref]['mag'])) + if not np.isfinite(frd): + frd = None + + phot_summary = {'fileid_img': fileid_img, 'fileid_phot': fileid, + 'fileid_template': tab.meta['template'], 'fileid_diffimg': tab.meta['diffimg'], + 'targetid': targetid, 'obstime': tab.meta['obstime-bmjd'], + 'photfilter': tab.meta['photfilter'], 'mag_raw': float(tab[indx_raw]['mag']), + 'mag_raw_error': float(tab[indx_raw]['mag_error']), + 'mag_sub': None if not any(indx_sub) else float(tab[indx_sub]['mag']), + 'mag_sub_error': None if not any(indx_sub) else float(tab[indx_sub]['mag_error']), + 'zeropoint': float(tab.meta['zp']), 'zeropoint_error': float(tab.meta['zp_error']), + 'zeropoint_diff': float(tab.meta['zp_diff']), 'fwhm': float(tab.meta['fwhm'].value), + 'seeing': float(tab.meta['seeing'].value), 'references_detected': int(np.sum(indx_ref)), + 'used_for_epsf': int(np.sum(tab['used_for_epsf'])), 'faintest_reference_detected': frd, + 'pipeline_version': tab.meta['version'], 'latest_version': new_version} + + db.cursor.execute("""INSERT INTO flows.photometry_details ( fileid_phot, fileid_img, fileid_template, @@ -628,9 +624,9 @@ def ingest_photometry_from_inbox(): %(pipeline_version)s );""", phot_summary) - db.cursor.execute("SELECT * FROM flows.photometry_summary WHERE fileid_img=%s;", [fileid_img]) - if db.cursor.fetchone() is None: - db.cursor.execute("""INSERT INTO flows.photometry_summary ( + db.cursor.execute("SELECT * FROM flows.photometry_summary WHERE fileid_img=%s;", [fileid_img]) + if db.cursor.fetchone() is None: + db.cursor.execute("""INSERT INTO flows.photometry_summary ( fileid_phot, fileid_img, fileid_template, @@ -659,8 +655,8 @@ def ingest_photometry_from_inbox(): %(pipeline_version)s, %(latest_version)s );""", phot_summary) - else: - db.cursor.execute("""UPDATE flows.photometry_summary SET + else: + db.cursor.execute("""UPDATE flows.photometry_summary SET fileid_phot=%(fileid_phot)s, targetid=%(targetid)s, fileid_template=%(fileid_template)s, @@ -675,93 +671,99 @@ def ingest_photometry_from_inbox(): latest_version=%(latest_version)s WHERE fileid_img=%(fileid_img)s;""", phot_summary) - # Update the photometry status to done: - db.cursor.execute("UPDATE flows.photometry_status SET status='done' WHERE fileid=%(fileid_img)s AND status='ingest';", phot_summary) - - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET fileid=%s,status='ok' WHERE logid=%s;", [fileid, uploadlogid]) - - db.conn.commit() - - except: # noqa: E722, pragma: no cover - db.conn.rollback() - if newpath is not None 
and os.path.isdir(os.path.dirname(newpath)): - shutil.rmtree(os.path.dirname(newpath)) - raise - else: - # Set file and directory permissions: - # TODO: Can this not be handled in a more elegant way? - os.chmod(os.path.join(rootdir_archive, targetname), 0o2750) - os.chmod(os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}'), 0o2750) - os.chmod(os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}', f'v{new_version:02d}'), 0o2550) - for f in os.listdir(os.path.dirname(newpath)): - os.chmod(os.path.join(os.path.dirname(newpath), f), 0o0440) - - logger.info("DELETE THE ORIGINAL FILE") - if os.path.isfile(fpath): - os.remove(fpath) - if uploadlogid: - db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL WHERE logid=%s;", [uploadlogid]) - db.conn.commit() - -#-------------------------------------------------------------------------------------------------- + # Update the photometry status to done: + db.cursor.execute( + "UPDATE flows.photometry_status SET status='done' WHERE fileid=%(fileid_img)s AND status='ingest';", + phot_summary) + + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET fileid=%s,status='ok' WHERE logid=%s;", + [fileid, uploadlogid]) + + db.conn.commit() + + except: # noqa: E722, pragma: no cover + db.conn.rollback() + if newpath is not None and os.path.isdir(os.path.dirname(newpath)): + shutil.rmtree(os.path.dirname(newpath)) + raise + else: + # Set file and directory permissions: + # TODO: Can this not be handled in a more elegant way? + os.chmod(os.path.join(rootdir_archive, targetname), 0o2750) + os.chmod(os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}'), 0o2750) + os.chmod(os.path.join(rootdir_archive, targetname, f'{fileid_img:05d}', f'v{new_version:02d}'), 0o2550) + for f in os.listdir(os.path.dirname(newpath)): + os.chmod(os.path.join(os.path.dirname(newpath), f), 0o0440) + + logger.info("DELETE THE ORIGINAL FILE") + if os.path.isfile(fpath): + os.remove(fpath) + if uploadlogid: + db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL WHERE logid=%s;", [uploadlogid]) + db.conn.commit() + + +# -------------------------------------------------------------------------------------------------- def cleanup_inbox(): - """ - Cleanup of inbox directory - """ - rootdir_inbox = '/flows/inbox' - - # Just a simple check to begin with: - if not os.path.isdir(rootdir_inbox): - raise FileNotFoundError("INBOX could not be found.") - - # Remove empty directories: - for inputtype in ('science', 'templates', 'subtracted', 'photometry', 'replace'): - for dpath in glob.iglob(os.path.join(rootdir_inbox, '*', inputtype)): - if not os.listdir(dpath): - os.rmdir(dpath) - - for dpath in glob.iglob(os.path.join(rootdir_inbox, '*')): - if os.path.isdir(dpath) and not os.listdir(dpath): - os.rmdir(dpath) - - # Delete left-over files in the database tables, that have been removed from disk: - with AADC_DB() as db: - db.cursor.execute("SELECT logid,uploadpath FROM flows.uploadlog WHERE uploadpath IS NOT NULL;") - for row in db.cursor.fetchall(): - if not os.path.isfile(os.path.join(rootdir_inbox, row['uploadpath'])): - print("MARK AS DELETED IN DATABASE: " + row['uploadpath']) - db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL,status='File deleted' WHERE logid=%s;", [row['logid']]) - db.conn.commit() - -#-------------------------------------------------------------------------------------------------- + """ + Cleanup of inbox directory + """ + rootdir_inbox = '/flows/inbox' + + # Just a simple check to begin with: + if not 
os.path.isdir(rootdir_inbox):
+        raise FileNotFoundError("INBOX could not be found.")
+
+    # Remove empty directories:
+    for inputtype in ('science', 'templates', 'subtracted', 'photometry', 'replace'):
+        for dpath in glob.iglob(os.path.join(rootdir_inbox, '*', inputtype)):
+            if not os.listdir(dpath):
+                os.rmdir(dpath)
+
+    for dpath in glob.iglob(os.path.join(rootdir_inbox, '*')):
+        if os.path.isdir(dpath) and not os.listdir(dpath):
+            os.rmdir(dpath)
+
+    # Delete left-over entries in the database tables for files that have been removed from disk:
+    with AADC_DB() as db:
+        db.cursor.execute("SELECT logid,uploadpath FROM flows.uploadlog WHERE uploadpath IS NOT NULL;")
+        for row in db.cursor.fetchall():
+            if not os.path.isfile(os.path.join(rootdir_inbox, row['uploadpath'])):
+                print("MARK AS DELETED IN DATABASE: " + row['uploadpath'])
+                db.cursor.execute("UPDATE flows.uploadlog SET uploadpath=NULL,status='File deleted' WHERE logid=%s;",
+                                  [row['logid']])
+                db.conn.commit()
+
+
+# --------------------------------------------------------------------------------------------------
 if __name__ == '__main__':
-	logging_level = logging.INFO
-
-	# Setup logging:
-	formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
-	console = logging.StreamHandler(sys.stdout)
-	console.setFormatter(formatter)
-	logger = logging.getLogger(__name__)
-	if not logger.hasHandlers():
-		logger.addHandler(console)
-	logger.setLevel(logging_level)
-
-	# Add a CounterFilter to the logger, which will count the number of log-records
-	# being passed through the logger. Can be used to count the number of errors/warnings:
-	counter = CounterFilter()
-	logger.addFilter(counter)
-
-	# Run the ingests and cleanup:
-	ingest_from_inbox()
-	ingest_photometry_from_inbox()
-	cleanup_inbox()
-
-	# Check the number of errors or warnings issued, and convert these to a return-code:
-	logcounts = counter.counter
-	if logcounts.get('ERROR', 0) > 0 or logcounts.get('CRITICAL', 0) > 0:
-		sys.exit(4)
-	elif logcounts.get('WARNING', 0) > 0:
-		sys.exit(3)
-	sys.exit(0)
+    logging_level = logging.INFO
+
+    # Setup logging:
+    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+    console = logging.StreamHandler(sys.stdout)
+    console.setFormatter(formatter)
+    logger = logging.getLogger(__name__)
+    if not logger.hasHandlers():
+        logger.addHandler(console)
+    logger.setLevel(logging_level)
+
+    # Add a CounterFilter to the logger, which will count the number of log-records
+    # being passed through the logger.
Can be used to count the number of errors/warnings: + counter = CounterFilter() + logger.addFilter(counter) + + # Run the ingests and cleanup: + ingest_from_inbox() + ingest_photometry_from_inbox() + cleanup_inbox() + + # Check the number of errors or warnings issued, and convert these to a return-code: + logcounts = counter.counter + if logcounts.get('ERROR', 0) > 0 or logcounts.get('CRITICAL', 0) > 0: + sys.exit(4) + elif logcounts.get('WARNING', 0) > 0: + sys.exit(3) + sys.exit(0) diff --git a/run_photometry.py b/run_photometry.py index 4614a6b..134f85e 100644 --- a/run_photometry.py +++ b/run_photometry.py @@ -15,173 +15,171 @@ import multiprocessing from flows import api, photometry, load_config + # -------------------------------------------------------------------------------------------------- -def process_fileid(fid, output_folder_root=None, attempt_imagematch=True, autoupload=False, - keep_diff_fixed=False, cm_timeout=None): - logger = logging.getLogger('flows') - logging.captureWarnings(True) - logger_warn = logging.getLogger('py.warnings') - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S") - - datafile = api.get_datafile(fid) - target_name = datafile['target_name'] - - # Folder to save output: - output_folder = os.path.join(output_folder_root, target_name, f'{fid:05d}') - - photfile = None - _filehandler = None - try: - # Set the status to indicate that we have started processing: - if autoupload: - api.set_photometry_status(fid, 'running') - - # Create the output directory if it doesn't exist: - os.makedirs(output_folder, exist_ok=True) - - # Also write any logging output to the - _filehandler = logging.FileHandler(os.path.join(output_folder, 'photometry.log'), mode='w') - _filehandler.setFormatter(formatter) - _filehandler.setLevel(logging.INFO) - logger.addHandler(_filehandler) - logger_warn.addHandler(_filehandler) - - photfile = photometry( - fileid=fid, - output_folder=output_folder, - attempt_imagematch=attempt_imagematch, - keep_diff_fixed=keep_diff_fixed, - cm_timeout=cm_timeout) - - except (SystemExit, KeyboardInterrupt): - logger.error("Aborted by user or system.") - if os.path.exists(output_folder): - shutil.rmtree(output_folder, ignore_errors=True) - photfile = None - if autoupload: - api.set_photometry_status(fid, 'abort') - - except: # noqa: E722, pragma: no cover - logger.exception("Photometry failed") - photfile = None - if autoupload: - api.set_photometry_status(fid, 'error') - - if _filehandler is not None: - logger.removeHandler(_filehandler) - logger_warn.removeHandler(_filehandler) - - if photfile is not None: - if autoupload: - api.upload_photometry(fid, delete_completed=True) - api.set_photometry_status(fid, 'ingest') - - return photfile +def process_fileid(fid, output_folder_root=None, attempt_imagematch=True, autoupload=False, keep_diff_fixed=False, + cm_timeout=None): + logger = logging.getLogger('flows') + logging.captureWarnings(True) + logger_warn = logging.getLogger('py.warnings') + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S") + + datafile = api.get_datafile(fid) + target_name = datafile['target_name'] + + # Folder to save output: + output_folder = os.path.join(output_folder_root, target_name, f'{fid:05d}') + + photfile = None + _filehandler = None + try: + # Set the status to indicate that we have started processing: + if autoupload: + api.set_photometry_status(fid, 'running') + + # Create the output directory if it doesn't exist: + 
os.makedirs(output_folder, exist_ok=True)
+
+        # Also write any logging output to the log file:
+        _filehandler = logging.FileHandler(os.path.join(output_folder, 'photometry.log'), mode='w')
+        _filehandler.setFormatter(formatter)
+        _filehandler.setLevel(logging.INFO)
+        logger.addHandler(_filehandler)
+        logger_warn.addHandler(_filehandler)
+
+        photfile = photometry(fileid=fid, output_folder=output_folder, attempt_imagematch=attempt_imagematch,
+                              keep_diff_fixed=keep_diff_fixed, cm_timeout=cm_timeout)
+
+    except (SystemExit, KeyboardInterrupt):
+        logger.error("Aborted by user or system.")
+        if os.path.exists(output_folder):
+            shutil.rmtree(output_folder, ignore_errors=True)
+        photfile = None
+        if autoupload:
+            api.set_photometry_status(fid, 'abort')
+
+    except:  # noqa: E722, pragma: no cover
+        logger.exception("Photometry failed")
+        photfile = None
+        if autoupload:
+            api.set_photometry_status(fid, 'error')
+
+    if _filehandler is not None:
+        logger.removeHandler(_filehandler)
+        logger_warn.removeHandler(_filehandler)
+
+    if photfile is not None:
+        if autoupload:
+            api.upload_photometry(fid, delete_completed=True)
+            api.set_photometry_status(fid, 'ingest')
+
+    return photfile
+

 # --------------------------------------------------------------------------------------------------
 def main():
-	# Parse command line arguments:
-	parser = argparse.ArgumentParser(description='Run photometry pipeline.')
-	parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true')
-	parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true')
-	parser.add_argument('-o', '--overwrite', help='Overwrite existing results.', action='store_true')
-
-	group = parser.add_argument_group('Selecting which files to process')
-	group.add_argument('--fileid', type=int, default=None, action='append', help="Process this file ID. Overrides all other filters.")
-	group.add_argument('--targetid', type=int, default=None, action='append', help="Only process files from this target.")
-	group.add_argument('--filter', type=str, default=None, choices=['missing', 'all', 'error'])
-	group.add_argument('--minversion', type=str, default=None, help="Include files not previously processed with at least this version.")
-
-	group = parser.add_argument_group('Processing settings')
-	group.add_argument('--threads', type=int, default=1, help="Number of parallel threads to use.")
-	group.add_argument('--no-imagematch', action='store_true', help="Disable ImageMatch.")
-	group.add_argument('--autoupload', action='store_true',
-		help="Automatically upload completed photometry to Flows website. Only do this, if you know what you are doing!")
-	group.add_argument('--fixposdiff', action='store_true',
-		help="Fix SN position during PSF photometry of difference image.
Useful when difference image is noisy.") - group.add_argument('--wcstimeout', type=int, default=None, help="Timeout in Seconds for WCS.") - args = parser.parse_args() - - # Ensure that all input has been given: - if not args.fileid and not args.targetid and args.filter is None: - parser.error("Please select either a specific FILEID .") - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Number of threads to use: - threads = args.threads - if threads <= 0: - threads = multiprocessing.cpu_count() - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S") - console = logging.StreamHandler(sys.stdout) - console.setFormatter(formatter) - logger = logging.getLogger('flows') - if not logger.hasHandlers(): - logger.addHandler(console) - logger.propagate = False - logger.setLevel(logging_level) - - # If we have started a new processing, perform a cleanup of the - # photometry status indicator. This will change all processes - # still marked as "running" to "abort" if they have been running - # for more than a day: - if args.autoupload: - api.cleanup_photometry_status() - - if args.fileid is not None: - # Run the specified fileids: - fileids = args.fileid - else: - # Ask the API for a list of fileids which are yet to be processed: - if args.targetid is not None: - fileids = [] - for targid in args.targetid: - fileids += api.get_datafiles(targetid=targid, filt=args.filter, minversion=args.minversion) - else: - fileids = api.get_datafiles(filt=args.filter, minversion=args.minversion) - - # Remove duplicates from fileids to be processed: - fileids = list(set(fileids)) - - # Ask the config where we should store the output: - config = load_config() - output_folder_root = config.get('photometry', 'output', fallback='.') - - # Create function wrapper: - process_fileid_wrapper = functools.partial( - process_fileid, - output_folder_root=output_folder_root, - attempt_imagematch=not args.no_imagematch, - autoupload=args.autoupload, - keep_diff_fixed=args.fixposdiff, - cm_timeout=args.wcstimeout) - - if threads > 1: - # Disable printing info messages from the parent function. - # It is going to be all jumbled up anyway. - logger.setLevel(logging.WARNING) - - # There is more than one area to process, so let's start - # a process pool and process them in parallel: - with multiprocessing.Pool(threads) as pool: - for res in pool.imap_unordered(process_fileid_wrapper, fileids): - pass - - else: - # Only single thread so simply run it directly: - for fid in fileids: - logger.info("=" * 72) - logger.info(fid) - logger.info("=" * 72) - process_fileid_wrapper(fid) - -#-------------------------------------------------------------------------------------------------- + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Run photometry pipeline.') + parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') + parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') + parser.add_argument('-o', '--overwrite', help='Overwrite existing results.', action='store_true') + + group = parser.add_argument_group('Selecting which files to process') + group.add_argument('--fileid', type=int, default=None, action='append', + help="Process this file ID. 
Overrides all other filters.")
+    group.add_argument('--targetid', type=int, default=None, action='append',
+                       help="Only process files from this target.")
+    group.add_argument('--filter', type=str, default=None, choices=['missing', 'all', 'error'])
+    group.add_argument('--minversion', type=str, default=None,
+                       help="Include files not previously processed with at least this version.")
+
+    group = parser.add_argument_group('Processing settings')
+    group.add_argument('--threads', type=int, default=1, help="Number of parallel threads to use.")
+    group.add_argument('--no-imagematch', action='store_true', help="Disable ImageMatch.")
+    group.add_argument('--autoupload', action='store_true',
+                       help="Automatically upload completed photometry to Flows website. Only do this if you know what you are doing!")
+    group.add_argument('--fixposdiff', action='store_true',
+                       help="Fix SN position during PSF photometry of difference image. Useful when difference image is noisy.")
+    group.add_argument('--wcstimeout', type=int, default=None, help="Timeout in seconds for WCS.")
+    args = parser.parse_args()
+
+    # Ensure that all input has been given:
+    if not args.fileid and not args.targetid and args.filter is None:
+        parser.error("Please select either a specific FILEID, a TARGETID or a filter.")
+
+    # Set logging level:
+    logging_level = logging.INFO
+    if args.quiet:
+        logging_level = logging.WARNING
+    elif args.debug:
+        logging_level = logging.DEBUG
+
+    # Number of threads to use:
+    threads = args.threads
+    if threads <= 0:
+        threads = multiprocessing.cpu_count()
+
+    # Setup logging:
+    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', "%Y-%m-%d %H:%M:%S")
+    console = logging.StreamHandler(sys.stdout)
+    console.setFormatter(formatter)
+    logger = logging.getLogger('flows')
+    if not logger.hasHandlers():
+        logger.addHandler(console)
+    logger.propagate = False
+    logger.setLevel(logging_level)
+
+    # If we have started a new processing, perform a cleanup of the
+    # photometry status indicator. This will change all processes
+    # still marked as "running" to "abort" if they have been running
+    # for more than a day:
+    if args.autoupload:
+        api.cleanup_photometry_status()
+
+    if args.fileid is not None:
+        # Run the specified fileids:
+        fileids = args.fileid
+    else:
+        # Ask the API for a list of fileids which are yet to be processed:
+        if args.targetid is not None:
+            fileids = []
+            for targid in args.targetid:
+                fileids += api.get_datafiles(targetid=targid, filt=args.filter, minversion=args.minversion)
+        else:
+            fileids = api.get_datafiles(filt=args.filter, minversion=args.minversion)
+
+    # Remove duplicates from fileids to be processed:
+    fileids = list(set(fileids))
+
+    # Ask the config where we should store the output:
+    config = load_config()
+    output_folder_root = config.get('photometry', 'output', fallback='.')
+
+    # Create function wrapper:
+    process_fileid_wrapper = functools.partial(process_fileid, output_folder_root=output_folder_root,
+                                               attempt_imagematch=not args.no_imagematch, autoupload=args.autoupload,
+                                               keep_diff_fixed=args.fixposdiff, cm_timeout=args.wcstimeout)
+
+    if threads > 1:
+        # Disable printing info messages from the parent function.
+        # It is going to be all jumbled up anyway.
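The parallel branch below follows the usual functools.partial + Pool.imap_unordered shape; a minimal runnable sketch of the same idea (the worker and its arguments are made up):

import functools
import multiprocessing

def worker(fid, prefix=''):
    return f'{prefix}{fid:05d}'  # stand-in for process_fileid

if __name__ == '__main__':
    wrapped = functools.partial(worker, prefix='file-')
    with multiprocessing.Pool(2) as pool:
        for res in pool.imap_unordered(wrapped, [1, 2, 3]):
            print(res)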
+ logger.setLevel(logging.WARNING) + + # There is more than one area to process, so let's start + # a process pool and process them in parallel: + with multiprocessing.Pool(threads) as pool: + for res in pool.imap_unordered(process_fileid_wrapper, fileids): + pass + + else: + # Only single thread so simply run it directly: + for fid in fileids: + logger.info("=" * 72) + logger.info(fid) + logger.info("=" * 72) + process_fileid_wrapper(fid) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - main() + main() diff --git a/run_plotlc.py b/run_plotlc.py index 519fbfd..938be88 100644 --- a/run_plotlc.py +++ b/run_plotlc.py @@ -18,131 +18,132 @@ import mplcursors import seaborn as sns -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def main(): - # All available filters: - all_filters = list(api.get_filters().keys()) + # All available filters: + all_filters = list(api.get_filters().keys()) - # Parser: - parser = argparse.ArgumentParser(description='Plot photometry for target') - parser.add_argument('--target', '-t', type=str, required=True, help="""Target identifier: + # Parser: + parser = argparse.ArgumentParser(description='Plot photometry for target') + parser.add_argument('--target', '-t', type=str, required=True, help="""Target identifier: Can be either the SN name (e.g. 2019yvr) or the Flows target ID.""") - parser.add_argument('--fileid', '-i', type=int, default=None, action='append', help='Specific file ids.') - parser.add_argument('--filter', '-f', type=str, default=None, action='append', choices=all_filters, - help=f'Photmetric filter to plot. If not provided will plot all. 
Choose between {all_filters}') - parser.add_argument('--offset', '-jd', type=float, default=2458800.0) - parser.add_argument('--subonly', action='store_true', help='Only show template subtracted data points.') - args = parser.parse_args() - - # To use when only plotting some filters - usefilts = args.filters - if usefilts is not None: - usefilts = set(args.filters) - - # To use when only plotting some fileids - # Parse input fileids: - if args.fileid is not None: - # Plot the specified fileid: - fileids = args.fileid - else: - fileids = [] - if len(fileids) > 1: - raise NotImplementedError("This has not been implemented yet") - - # Get the name of the target: - snname = args.target - if snname.isdigit(): - datafiles = api.get_datafiles(int(snname), filt='all') - snname = api.get_datafile(datafiles[0])['target_name'] - - # Change to directory, raise if it does not exist - config = load_config() - workdir_root = config.get('photometry', 'output', fallback='.') - sndir = os.path.join(workdir_root, snname) - if not os.path.isdir(sndir): - print('No such directory as',sndir) - return - - # Get list of photometry files - phot_files = glob.iglob(os.path.join(sndir, '*', 'photometry.ecsv')) - - # Load all data into astropy table - tablerows = [] - for file in phot_files: - # Load photometry file into Table: - AT = Table.read(file, format='ascii.ecsv') - - # Pull out meta-data: - fileid = AT.meta['fileid'] - filt = AT.meta['photfilter'] - jd = Time(AT.meta['obstime-bmjd'], format='mjd', scale='tdb').jd - - # get phot of diff image - AT.add_index('starid') - if -1 in AT['starid']: - mag, mag_err = AT.loc[-1]['mag'], AT.loc[-1]['mag_error'] - sub = True - elif 0 in AT['starid']: - print('No subtraction found for:',file,'in filter',filt) - mag,mag_err = AT.loc[0]['mag'],AT.loc[0]['mag_error'] - sub = False - else: - print('No object phot found, skipping: \n',file) - continue - - tablerows.append((jd, mag, mag_err, filt, sub, fileid)) - - phot = Table( - rows=tablerows, - names=['jd','mag','mag_err','filter','sub','fileid'], - dtype=['float64','float64','float64','S64','bool','int64']) - - # Create list of filters to plot: - filters = list(np.unique(phot['filter'])) - if usefilts: - filters = set(filters).intersection(usefilts) - - # Split photometry table - shifts = dict(zip(filters, np.zeros(len(filters)))) - - # Create the plot: - plots_interactive() - sns.set(style='ticks') - dpi_mult = 1 if not args.subonly else 2 - fig, ax = plt.subplots(figsize=(6.4,4), dpi=130*dpi_mult) - fig.subplots_adjust(top=0.95, left=0.1, bottom=0.1, right=0.97) - - cps = sns.color_palette() - colors = dict(zip(filters,(cps[2],cps[3],cps[0],cps[-1],cps[1]))) - - if args.subonly: - for filt in filters: - lc = phot[(phot['filter'] == filt) & phot['sub']] - ax.errorbar(lc['jd'] - args.offset, lc['mag'] + shifts[filt], lc['mag_err'], - marker='s', linestyle='None', label=filt, color=colors[filt]) - - else: - for filt in filters: - lc = phot[phot['filter'] == filt] - ax.errorbar(lc['jd'] - args.offset, lc['mag'] + shifts[filt], lc['mag_err'], - marker='s', linestyle='None', label=filt, color=colors[filt]) - - ax.invert_yaxis() - ax.legend() - ax.set_xlabel('JD - ' + str(args.offset), fontsize=16) - ax.set_ylabel('App. 
Mag', fontsize=16)
-	ax.set_title(snname)
-
-	# Make the points interactive:
-	def annotate(sel):
-		lc = phot[phot['filter'] == str(sel.artist.get_label())]
-		point = lc[sel.target.index]
-		point = dict(zip(point.colnames, point)) # Convert table row to dict
-		return sel.annotation.set_text("Fileid: {fileid:d}\nJD: {jd:.3f}\nMag: {mag:.2f}$\\pm${mag_err:.2f}".format(**point))
-
-	mplcursors.cursor(ax).connect("add", annotate)
-	plt.show(block=True)
-
-#--------------------------------------------------------------------------------------------------
+    parser.add_argument('--fileid', '-i', type=int, default=None, action='append', help='Specific file ids.')
+    parser.add_argument('--filter', '-f', type=str, default=None, action='append', choices=all_filters,
+                        help=f'Photometric filter to plot. If not provided will plot all. Choose between {all_filters}')
+    parser.add_argument('--offset', '-jd', type=float, default=2458800.0)
+    parser.add_argument('--subonly', action='store_true', help='Only show template subtracted data points.')
+    args = parser.parse_args()
+
+    # To use when only plotting some filters
+    usefilts = args.filter
+    if usefilts is not None:
+        usefilts = set(args.filter)
+
+    # To use when only plotting some fileids
+    # Parse input fileids:
+    if args.fileid is not None:
+        # Plot the specified fileid:
+        fileids = args.fileid
+    else:
+        fileids = []
+    if len(fileids) > 1:
+        raise NotImplementedError("This has not been implemented yet")
+
+    # Get the name of the target:
+    snname = args.target
+    if snname.isdigit():
+        datafiles = api.get_datafiles(int(snname), filt='all')
+        snname = api.get_datafile(datafiles[0])['target_name']
+
+    # Change to directory, raise if it does not exist
+    config = load_config()
+    workdir_root = config.get('photometry', 'output', fallback='.')
+    sndir = os.path.join(workdir_root, snname)
+    if not os.path.isdir(sndir):
+        print('No such directory as', sndir)
+        return
+
+    # Get list of photometry files
+    phot_files = glob.iglob(os.path.join(sndir, '*', 'photometry.ecsv'))
+
+    # Load all data into astropy table
+    tablerows = []
+    for file in phot_files:
+        # Load photometry file into Table:
+        AT = Table.read(file, format='ascii.ecsv')
+
+        # Pull out meta-data:
+        fileid = AT.meta['fileid']
+        filt = AT.meta['photfilter']
+        jd = Time(AT.meta['obstime-bmjd'], format='mjd', scale='tdb').jd
+
+        # get phot of diff image
+        AT.add_index('starid')
+        if -1 in AT['starid']:
+            mag, mag_err = AT.loc[-1]['mag'], AT.loc[-1]['mag_error']
+            sub = True
+        elif 0 in AT['starid']:
+            print('No subtraction found for:', file, 'in filter', filt)
+            mag, mag_err = AT.loc[0]['mag'], AT.loc[0]['mag_error']
+            sub = False
+        else:
+            print('No object phot found, skipping: \n', file)
+            continue
+
+        tablerows.append((jd, mag, mag_err, filt, sub, fileid))
+
+    phot = Table(rows=tablerows, names=['jd', 'mag', 'mag_err', 'filter', 'sub', 'fileid'],
+                 dtype=['float64', 'float64', 'float64', 'S64', 'bool', 'int64'])
+
+    # Create list of filters to plot:
+    filters = list(np.unique(phot['filter']))
+    if usefilts:
+        filters = set(filters).intersection(usefilts)
+
+    # Split photometry table
+    shifts = dict(zip(filters, np.zeros(len(filters))))
+
+    # Create the plot:
+    plots_interactive()
+    sns.set(style='ticks')
+    dpi_mult = 1 if not args.subonly else 2
+    fig, ax = plt.subplots(figsize=(6.4, 4), dpi=130 * dpi_mult)
+    fig.subplots_adjust(top=0.95, left=0.1, bottom=0.1, right=0.97)
+
+    cps = sns.color_palette()
+    colors = dict(zip(filters, (cps[2], cps[3], cps[0], cps[-1], cps[1])))
+
+    if args.subonly:
+        for 
filt in filters: + lc = phot[(phot['filter'] == filt) & phot['sub']] + ax.errorbar(lc['jd'] - args.offset, lc['mag'] + shifts[filt], lc['mag_err'], marker='s', linestyle='None', + label=filt, color=colors[filt]) + + else: + for filt in filters: + lc = phot[phot['filter'] == filt] + ax.errorbar(lc['jd'] - args.offset, lc['mag'] + shifts[filt], lc['mag_err'], marker='s', linestyle='None', + label=filt, color=colors[filt]) + + ax.invert_yaxis() + ax.legend() + ax.set_xlabel('JD - ' + str(args.offset), fontsize=16) + ax.set_ylabel('App. Mag', fontsize=16) + ax.set_title(snname) + + # Make the points interactive: + def annotate(sel): + lc = phot[phot['filter'] == str(sel.artist.get_label())] + point = lc[sel.target.index] + point = dict(zip(point.colnames, point)) # Convert table row to dict + return sel.annotation.set_text( + "Fileid: {fileid:d}\nJD: {jd:.3f}\nMag: {mag:.2f}$\\pm${mag_err:.2f}".format(**point)) + + mplcursors.cursor(ax).connect("add", annotate) + plt.show(block=True) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - main() + main() diff --git a/run_querytns.py b/run_querytns.py index 1b11a3c..e837adc 100644 --- a/run_querytns.py +++ b/run_querytns.py @@ -18,117 +18,112 @@ from datetime import datetime, timedelta, timezone from flows import api, tns -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def main(): - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Query TNS and upload to Flows candidates.') - parser.add_argument('-d', '--debug', action='store_true', help='Print debug messages.') - parser.add_argument('-q', '--quiet', action='store_true', help='Only report warnings and errors.') - parser.add_argument('--zmax', type=float, default=0.105, help='Maximum redshift.') - parser.add_argument('--zmin', type=float, default=0.000000001, help='Minimum redshift.') - parser.add_argument('-b', '--days_begin', type=int, default=30, help='Discovery day at least X days before today.') - parser.add_argument('-e', '--days_end', type=int, default=3, help='Discovery day at most X days before today.') - parser.add_argument('-o', '--objtype', type=str, default=[3, 104], help='TNS objtype int given as comma separed string with no spaces') - parser.add_argument('-m', '--limit_months', type=int, default=2, help='Integer number of months to limit TNS search (for speed). 
\
+    # Parse command line arguments:
+    parser = argparse.ArgumentParser(description='Query TNS and upload to Flows candidates.')
+    parser.add_argument('-d', '--debug', action='store_true', help='Print debug messages.')
+    parser.add_argument('-q', '--quiet', action='store_true', help='Only report warnings and errors.')
+    parser.add_argument('--zmax', type=float, default=0.105, help='Maximum redshift.')
+    parser.add_argument('--zmin', type=float, default=0.000000001, help='Minimum redshift.')
+    parser.add_argument('-b', '--days_begin', type=int, default=30, help='Discovery day at least X days before today.')
+    parser.add_argument('-e', '--days_end', type=int, default=3, help='Discovery day at most X days before today.')
+    parser.add_argument('-o', '--objtype', type=str, default=[3, 104],
+                        help='TNS objtype int given as comma-separated string with no spaces')
+    parser.add_argument('-m', '--limit_months', type=int, default=2, help='Integer number of months to limit TNS search (for speed). \
 Should be greater than days_begin.')
-	parser.add_argument('--autoupload', action='store_true', help="Automatically upload targets to Flows website. Only do this, if you know what you are doing!")
-	args = parser.parse_args()
-
-	# Set logging level:
-	logging_level = logging.INFO
-	if args.quiet:
-		logging_level = logging.WARNING
-	elif args.debug:
-		logging_level = logging.DEBUG
-
-	# Setup logging:
-	formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
-	console = logging.StreamHandler()
-	console.setFormatter(formatter)
-	logger = logging.getLogger(__name__)
-	if not logger.hasHandlers():
-		logger.addHandler(console)
-	logger.setLevel(logging_level)
-
-	tqdm_settings = {'disable': None if logger.isEnabledFor(logging.INFO) else True}
-
-	# Try to load TNS config - only used for early stopping
-	try:
-		tns._load_tns_config()
-	except tns.TNSConfigError:
-		parser.error("Error in TNS configuration.")
-		return
-
-	# Calculate current date and date range to search:
-	date_now = datetime.now(timezone.utc).date()
-	date_end = date_now - timedelta(days=args.days_end)
-	date_begin = date_now - timedelta(days=args.days_begin)
-	logger.info('Date begin = %s, date_end = %s', date_begin, date_end)
-
-	# Query TNS for SN names
-	logger.info('Querying TNS for all targets, this may take awhile')
-	nms = tns.tns_getnames(
-		months=args.limit_months, # pre-limit TNS search to candidates reported in the last X months
-		date_begin=date_begin,
-		date_end=date_end,
-		zmin=args.zmin,
-		zmax=args.zmax,
-		objtype=args.objtype # Relevant TNS SN Ia subtypes.
-	)
-	logger.debug(nms)
-
-	if not nms:
-		logger.info("No targets were found.")
-		return
-
-	# Remove already existing names using flows api
-	included_names = ['SN' + target['target_name'] for target in api.get_targets()]
-	nms = list(set(nms) - set(included_names))
-	logger.info('Target names obtained: %s', nms)
-
-	# Regular Expression matching any string starting with "ztf"
-	regex_ztf = re.compile('^ztf', flags=re.IGNORECASE)
-	regex_sn = re.compile(r'^sn\s*', flags=re.IGNORECASE)
-
-	# Query TNS for object info using API, then upload to FLOWS using API.
-	num_uploaded = 0
-	if args.autoupload:
-		for name in tqdm(nms, **tqdm_settings):
-			sn = regex_sn.sub('', name)
-			logger.debug('querying TNS for: %s', sn)
-
-			# make GET request to TNS via API
-			reply = tns.tns_get_obj(sn)
-
-			# Parse output
-			if reply:
-				logger.debug('GET query successful')
-
-				# Extract object info
-				coord = SkyCoord(ra=reply['radeg'], dec=reply['decdeg'], unit='deg', frame='icrs')
-				discovery_date = Time(reply['discoverydate'], format='iso', scale='utc')
-				ztf = list(filter(regex_ztf.match, reply['internal_names']))
-				ztf = None if not ztf else ztf[0]
-				if 'object_type' in reply and 'name' in reply['object_type']:
-					sntype = regex_sn.sub('', reply['object_type']['name'])
-				else:
-					sntype = None
-
-				# Try to upload to FLOWS
-				newtargetid = api.add_target(reply['objname'], coord,
-					redshift=reply['redshift'],
-					discovery_date=discovery_date,
-					discovery_mag=reply['discoverymag'],
-					host_galaxy=reply['hostname'],
-					ztf=ztf,
-					sntype=sntype,
-					status='candidate',
-					project='flows')
-				logger.debug('Uploaded to FLOWS with targetid=%d', newtargetid)
-				num_uploaded += 1
-
-	logger.info("%d targets uploaded to Flows.", num_uploaded)
-
-#--------------------------------------------------------------------------------------------------
+    parser.add_argument('--autoupload', action='store_true',
+                        help="Automatically upload targets to Flows website. Only do this if you know what you are doing!")
+    args = parser.parse_args()
+
+    # Set logging level:
+    logging_level = logging.INFO
+    if args.quiet:
+        logging_level = logging.WARNING
+    elif args.debug:
+        logging_level = logging.DEBUG
+
+    # Setup logging:
+    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+    console = logging.StreamHandler()
+    console.setFormatter(formatter)
+    logger = logging.getLogger(__name__)
+    if not logger.hasHandlers():
+        logger.addHandler(console)
+    logger.setLevel(logging_level)
+
+    tqdm_settings = {'disable': None if logger.isEnabledFor(logging.INFO) else True}
+
+    # Try to load TNS config - only used for early stopping
+    try:
+        tns._load_tns_config()
+    except tns.TNSConfigError:
+        parser.error("Error in TNS configuration.")
+        return
+
+    # Calculate current date and date range to search:
+    date_now = datetime.now(timezone.utc).date()
+    date_end = date_now - timedelta(days=args.days_end)
+    date_begin = date_now - timedelta(days=args.days_begin)
+    logger.info('Date begin = %s, date_end = %s', date_begin, date_end)
+
+    # Query TNS for SN names
+    logger.info('Querying TNS for all targets, this may take a while')
+    nms = tns.tns_getnames(months=args.limit_months,  # pre-limit TNS search to candidates reported in the last X months
+                           date_begin=date_begin, date_end=date_end, zmin=args.zmin, zmax=args.zmax,
+                           objtype=args.objtype)  # Relevant TNS SN Ia subtypes.
+    logger.debug(nms)
+
+    if not nms:
+        logger.info("No targets were found.")
+        return
+
+    # Remove already existing names using flows api
+    included_names = ['SN' + target['target_name'] for target in api.get_targets()]
+    nms = list(set(nms) - set(included_names))
+    logger.info('Target names obtained: %s', nms)
+
+    # Regular Expression matching any string starting with "ztf"
+    regex_ztf = re.compile('^ztf', flags=re.IGNORECASE)
+    regex_sn = re.compile(r'^sn\s*', flags=re.IGNORECASE)
+
+    # Query TNS for object info using API, then upload to FLOWS using API.
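+    # The two compiled patterns above drive the name handling below: regex_sn
+    # strips an optional leading "SN" (plus any whitespace) before querying
+    # TNS, and regex_ztf picks ZTF identifiers out of reply['internal_names'].
+    # A small self-contained example of that behaviour (inputs are made up):
+    #
+    #     import re
+    #     regex_ztf = re.compile('^ztf', flags=re.IGNORECASE)
+    #     regex_sn = re.compile(r'^sn\s*', flags=re.IGNORECASE)
+    #     assert regex_sn.sub('', 'SN 2019yvr') == '2019yvr'
+    #     internal = ['ATLAS20xyz', 'ZTF20aabqkxs']
+    #     assert list(filter(regex_ztf.match, internal)) == ['ZTF20aabqkxs']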
+ num_uploaded = 0 + if args.autoupload: + for name in tqdm(nms, **tqdm_settings): + sn = regex_sn.sub('', name) + logger.debug('querying TNS for: %s', sn) + + # make GET request to TNS via API + reply = tns.tns_get_obj(sn) + + # Parse output + if reply: + logger.debug('GET query successful') + + # Extract object info + coord = SkyCoord(ra=reply['radeg'], dec=reply['decdeg'], unit='deg', frame='icrs') + discovery_date = Time(reply['discoverydate'], format='iso', scale='utc') + ztf = list(filter(regex_ztf.match, reply['internal_names'])) + ztf = None if not ztf else ztf[0] + if 'object_type' in reply and 'name' in reply['object_type']: + sntype = regex_sn.sub('', reply['object_type']['name']) + else: + sntype = None + + # Try to upload to FLOWS + newtargetid = api.add_target(reply['objname'], coord, redshift=reply['redshift'], + discovery_date=discovery_date, discovery_mag=reply['discoverymag'], + host_galaxy=reply['hostname'], ztf=ztf, sntype=sntype, status='candidate', + project='flows') + logger.debug('Uploaded to FLOWS with targetid=%d', newtargetid) + num_uploaded += 1 + + logger.info("%d targets uploaded to Flows.", num_uploaded) + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - main() + main() diff --git a/run_upload_photometry.py b/run_upload_photometry.py index e4f859c..1480242 100644 --- a/run_upload_photometry.py +++ b/run_upload_photometry.py @@ -10,35 +10,37 @@ import logging from flows import api -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def main(): - # Parse command line arguments: - parser = argparse.ArgumentParser(description='Upload photometry.') - parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') - parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') - parser.add_argument('fileids', type=int, help='File IDs to be uploaded.', nargs='+') - args = parser.parse_args() - - # Set logging level: - logging_level = logging.INFO - if args.quiet: - logging_level = logging.WARNING - elif args.debug: - logging_level = logging.DEBUG - - # Setup logging: - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - console = logging.StreamHandler() - console.setFormatter(formatter) - logger = logging.getLogger('flows') - if not logger.hasHandlers(): - logger.addHandler(console) - logger.setLevel(logging_level) - - # Loop through the fileids and upload the results: - for fid in args.fileids: - api.upload_photometry(fid) - -#-------------------------------------------------------------------------------------------------- + # Parse command line arguments: + parser = argparse.ArgumentParser(description='Upload photometry.') + parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true') + parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true') + parser.add_argument('fileids', type=int, help='File IDs to be uploaded.', nargs='+') + args = parser.parse_args() + + # Set logging level: + logging_level = logging.INFO + if args.quiet: + logging_level = logging.WARNING + elif args.debug: + logging_level = logging.DEBUG + + # Setup logging: + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console = logging.StreamHandler() + console.setFormatter(formatter) + 
logger = logging.getLogger('flows')
+    if not logger.hasHandlers():
+        logger.addHandler(console)
+    logger.setLevel(logging_level)
+
+    # Loop through the fileids and upload the results:
+    for fid in args.fileids:
+        api.upload_photometry(fid)
+
+
+# --------------------------------------------------------------------------------------------------
 if __name__ == '__main__':
-	main()
+    main()
diff --git a/run_visibility.py b/run_visibility.py
index 574de4f..05934ab 100644
--- a/run_visibility.py
+++ b/run_visibility.py
@@ -9,15 +9,15 @@
 import flows

 if __name__ == '__main__':
-	# Parse command line arguments:
-	parser = argparse.ArgumentParser(description='Run photometry pipeline.')
-	parser.add_argument('-t', '--target', type=str, help='TIC identifier of target.', nargs='?', default=2)
-	parser.add_argument('-s', '--site', type=int, help='TIC identifier of target.', nargs='?', default=None)
-	parser.add_argument('-d', '--date', type=str, help='TIC identifier of target.', nargs='?', default=None)
-	parser.add_argument('-o', '--output', type=str, help='TIC identifier of target.', nargs='?', default=None)
-	args = parser.parse_args()
+    # Parse command line arguments:
+    parser = argparse.ArgumentParser(description='Plot visibility of target.')
+    parser.add_argument('-t', '--target', type=str, help='Target identifier.', nargs='?', default=2)
+    parser.add_argument('-s', '--site', type=int, help='Site identifier.', nargs='?', default=None)
+    parser.add_argument('-d', '--date', type=str, help='Date to plot visibility for.', nargs='?', default=None)
+    parser.add_argument('-o', '--output', type=str, help='Path to save the output plot.', nargs='?', default=None)
+    args = parser.parse_args()

-	if args.output is None:
-		plots_interactive()
+    if args.output is None:
+        plots_interactive()

-	flows.visibility(target=args.target, siteid=args.site, date=args.date, output=args.output)
+    flows.visibility(target=args.target, siteid=args.site, date=args.date, output=args.output)
diff --git a/setup.cfg b/setup.cfg
index 4af7b14..8c93ff0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,16 +1,11 @@
 [flake8]
 exclude = .git,__pycache__,notes
-max-line-length = 99
+max-line-length = 120 # To be compliant with black
+extend-ignore = E203 # To be compliant with black

 # Enable flake8-logging-format:
 enable-extensions = G

-# Configuration of flake8-tabs:
-use-flake8-tabs = True
-blank-lines-indent = never
-indent-tabs-def = 1
-indent-style = tab
-
 ignore =
 	E117, # over-indented (set when using tabs)
 	E127, # continuation line over-indented for visual indent
diff --git a/tests/conftest.py b/tests/conftest.py
index 89638b0..ea9d2db 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,60 +9,57 @@
 import pytest
 import sys
 import os
-#import shutil
+# import shutil
 import configparser
 import subprocess
 import shlex

 if sys.path[0] != os.path.abspath(os.path.join(os.path.dirname(__file__), '..')):
-	sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 def capture_cli(script, params=[], mpiexec=False):
+    if isinstance(params, str):
+        params = shlex.split(params)

-	if isinstance(params, str):
-		params = shlex.split(params)
+    cmd = [sys.executable, script.strip()] + list(params)
+    if mpiexec:
+        cmd = ['mpiexec', 
'-n', '2'] + cmd - cmd = [sys.executable, script.strip()] + list(params) - if mpiexec: - cmd = ['mpiexec', '-n', '2'] + cmd + print(f"Command: {cmd}") + proc = subprocess.Popen(cmd, cwd=os.path.join(os.path.dirname(__file__), '..'), stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=True) + out, err = proc.communicate() + exitcode = proc.returncode + proc.kill() - print(f"Command: {cmd}") - proc = subprocess.Popen(cmd, - cwd=os.path.join(os.path.dirname(__file__), '..'), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True - ) - out, err = proc.communicate() - exitcode = proc.returncode - proc.kill() + print(f"ExitCode: {exitcode:d}") + print("StdOut:\n%s" % out.strip()) + print("StdErr:\n%s" % err.strip()) + return out, err, exitcode - print(f"ExitCode: {exitcode:d}") - print("StdOut:\n%s" % out.strip()) - print("StdErr:\n%s" % err.strip()) - return out, err, exitcode -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- @pytest.fixture(scope='session') def SETUP_CONFIG(): - """ - Fixture which sets up a dummy config-file which allows for simple testing only. - """ - config_file = os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'flows', 'config.ini'))) - if os.path.exists(config_file): - yield config_file - else: - confstr = os.environ.get('FLOWS_CONFIG') - if confstr is None: - raise RuntimeError("Config file can not be set up.") + """ + Fixture which sets up a dummy config-file which allows for simple testing only. + """ + config_file = os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'flows', 'config.ini'))) + if os.path.exists(config_file): + yield config_file + else: + confstr = os.environ.get('FLOWS_CONFIG') + if confstr is None: + raise RuntimeError("Config file can not be set up.") - # Write minimal config file that can be used for testing: - config = configparser.ConfigParser() - config.read_string(confstr) - with open(config_file, 'w') as fid: - config.write(fid) - fid.flush() + # Write minimal config file that can be used for testing: + config = configparser.ConfigParser() + config.read_string(confstr) + with open(config_file, 'w') as fid: + config.write(fid) + fid.flush() - yield config_file - os.remove(config_file) + yield config_file + os.remove(config_file) diff --git a/tests/test_api.py b/tests/test_api.py index c630097..3fac1a9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -12,159 +12,161 @@ import numpy as np from astropy.coordinates import EarthLocation from astropy.table import Table -import conftest # noqa: F401 +import conftest # noqa: F401 from flows import api, load_config -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def test_api_get_targets(SETUP_CONFIG): + tab = api.get_targets() + print(tab) - tab = api.get_targets() - print(tab) + assert isinstance(tab, list) + assert len(tab) > 0 + for target in tab: + assert isinstance(target, dict) + assert 'target_name' in target + assert 'targetid' in target + assert 'ra' in target + assert 'decl' in target + assert 'target_status' in target - assert isinstance(tab, list) - assert len(tab) > 0 - for target in tab: - assert isinstance(target, dict) - assert 'target_name' in target - assert 'targetid' in target - 
assert 'ra' in target - assert 'decl' in target - assert 'target_status' in target -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def test_api_get_target(SETUP_CONFIG): + tab = api.get_target(2) + print(tab) - tab = api.get_target(2) - print(tab) + assert isinstance(tab, dict) + assert tab['target_name'] == '2019yvr' + assert tab['targetid'] == 2 + assert tab['target_status'] == 'target' + assert tab['ztf_id'] == 'ZTF20aabqkxs' - assert isinstance(tab, dict) - assert tab['target_name'] == '2019yvr' - assert tab['targetid'] == 2 - assert tab['target_status'] == 'target' - assert tab['ztf_id'] == 'ZTF20aabqkxs' -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def test_api_get_datafiles(SETUP_CONFIG): + tab = api.get_datafiles(targetid=2, filt='all') + print(tab) + assert isinstance(tab, list) + assert len(tab) > 0 + for fid in tab: + assert isinstance(fid, int) - tab = api.get_datafiles(targetid=2, filt='all') - print(tab) - assert isinstance(tab, list) - assert len(tab) > 0 - for fid in tab: - assert isinstance(fid, int) + fileid = tab[0] + tab = api.get_datafile(fileid) + print(tab) + assert tab['fileid'] == fileid + assert tab['targetid'] == 2 - fileid = tab[0] - tab = api.get_datafile(fileid) - print(tab) - assert tab['fileid'] == fileid - assert tab['targetid'] == 2 -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def test_api_get_filters(SETUP_CONFIG): + tab = api.get_filters() + print(tab) + assert isinstance(tab, dict) + for key, value in tab.items(): + assert isinstance(value, dict) + assert value['photfilter'] == key + assert 'wavelength_center' in value - tab = api.get_filters() - print(tab) - assert isinstance(tab, dict) - for key, value in tab.items(): - assert isinstance(value, dict) - assert value['photfilter'] == key - assert 'wavelength_center' in value -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def test_api_get_sites(SETUP_CONFIG): - - tab = api.get_all_sites() - print(tab) - assert isinstance(tab, list) - assert len(tab) > 0 - for site in tab: - assert isinstance(site, dict) - assert isinstance(site['siteid'], int) - assert 'sitename' in site - assert isinstance(site['EarthLocation'], EarthLocation) - - site0 = tab[0] - print(site0) - tab = api.get_site(site0['siteid']) - print(tab) - assert isinstance(tab, dict) - assert tab == site0 - -#-------------------------------------------------------------------------------------------------- + tab = api.get_all_sites() + print(tab) + assert isinstance(tab, list) + assert len(tab) > 0 + for site in tab: + assert isinstance(site, dict) + assert isinstance(site['siteid'], int) + assert 'sitename' in site + assert isinstance(site['EarthLocation'], EarthLocation) + + site0 = tab[0] + print(site0) + tab = api.get_site(site0['siteid']) + print(tab) + assert isinstance(tab, dict) + assert tab == site0 + + +# -------------------------------------------------------------------------------------------------- def 
test_api_get_catalog(SETUP_CONFIG): + cat = api.get_catalog(2, output='table') + print(cat) - cat = api.get_catalog(2, output='table') - print(cat) + assert isinstance(cat, dict) - assert isinstance(cat, dict) + target = cat['target'] + assert isinstance(target, Table) + assert len(target) == 1 + assert target['targetid'] == 2 + assert target['target_name'] == '2019yvr' - target = cat['target'] - assert isinstance(target, Table) - assert len(target) == 1 - assert target['targetid'] == 2 - assert target['target_name'] == '2019yvr' + ref = cat['references'] + assert isinstance(ref, Table) - ref = cat['references'] - assert isinstance(ref, Table) + avoid = cat['avoid'] + assert isinstance(avoid, Table) - avoid = cat['avoid'] - assert isinstance(avoid, Table) -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def test_api_get_lightcurve(SETUP_CONFIG): + tab = api.get_lightcurve(2) + print(tab) - tab = api.get_lightcurve(2) - print(tab) + assert isinstance(tab, Table) + assert len(tab) > 0 + assert 'time' in tab.colnames + assert 'mag_raw' in tab.colnames - assert isinstance(tab, Table) - assert len(tab) > 0 - assert 'time' in tab.colnames - assert 'mag_raw' in tab.colnames -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- def test_api_get_photometry(SETUP_CONFIG): - with tempfile.TemporaryDirectory() as tmpdir: - # Set cache to the temporary directory: - # FIXME: There is a potential race condition here! - config = load_config() - config.set('api', 'photometry_cache', tmpdir) - print(config) - - # The cache file should NOT exists: - assert not os.path.isfile(os.path.join(tmpdir, 'photometry-499.ecsv')), "Cache file already exists" - - # Download a photometry from API: - tab = api.get_photometry(499) - print(tab) - - # Basic tests of table: - assert isinstance(tab, Table) - assert len(tab) > 0 - assert 'starid' in tab.colnames - assert 'ra' in tab.colnames - assert 'decl' in tab.colnames - assert 'mag' in tab.colnames - assert 'mag_error' in tab.colnames - assert np.sum(tab['starid'] == 0) == 1, "There should be one starid=0" - - # Meta-information: - assert tab.meta['targetid'] == 2 - assert tab.meta['fileid'] == 179 - assert tab.meta['photfilter'] == 'B' - - # The cache file should now exists: - assert os.path.isfile(os.path.join(tmpdir, 'photometry-499.ecsv')), "Cache file does not exist" - - # Asking for the same photometry should now load from cache: - tab2 = api.get_photometry(499) - print(tab2) - - # The two tables should be identical: - assert tab2.meta == tab.meta - assert tab2.colnames == tab.colnames - for col in tab.colnames: - np.testing.assert_allclose(tab2[col], tab[col], equal_nan=True) - -#-------------------------------------------------------------------------------------------------- + with tempfile.TemporaryDirectory() as tmpdir: + # Set cache to the temporary directory: + # FIXME: There is a potential race condition here! 
+        config = load_config()
+        config.set('api', 'photometry_cache', tmpdir)
+        print(config)
+
+        # The cache file should NOT exist:
+        assert not os.path.isfile(os.path.join(tmpdir, 'photometry-499.ecsv')), "Cache file already exists"
+
+        # Download photometry from the API:
+        tab = api.get_photometry(499)
+        print(tab)
+
+        # Basic tests of table:
+        assert isinstance(tab, Table)
+        assert len(tab) > 0
+        assert 'starid' in tab.colnames
+        assert 'ra' in tab.colnames
+        assert 'decl' in tab.colnames
+        assert 'mag' in tab.colnames
+        assert 'mag_error' in tab.colnames
+        assert np.sum(tab['starid'] == 0) == 1, "There should be one starid=0"
+
+        # Meta-information:
+        assert tab.meta['targetid'] == 2
+        assert tab.meta['fileid'] == 179
+        assert tab.meta['photfilter'] == 'B'
+
+        # The cache file should now exist:
+        assert os.path.isfile(os.path.join(tmpdir, 'photometry-499.ecsv')), "Cache file does not exist"
+
+        # Asking for the same photometry should now load from cache:
+        tab2 = api.get_photometry(499)
+        print(tab2)
+
+        # The two tables should be identical:
+        assert tab2.meta == tab.meta
+        assert tab2.colnames == tab.colnames
+        for col in tab.colnames:
+            np.testing.assert_allclose(tab2[col], tab[col], equal_nan=True)
+
+
+# --------------------------------------------------------------------------------------------------
 if __name__ == '__main__':
-	pytest.main([__file__])
+    pytest.main([__file__])
diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py
index c82d21f..9771d53 100644
--- a/tests/test_catalogs.py
+++ b/tests/test_catalogs.py
@@ -10,107 +10,92 @@
 import numpy as np
 from astropy.coordinates import SkyCoord
 from astropy.table import Table
-import conftest # noqa: F401
+import conftest  # noqa: F401
 from flows import catalogs

-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 def test_query_simbad():
+    # Coordinates around test-object (2019yvr):
+    coo_centre = SkyCoord(ra=256.727512, dec=30.271482, unit='deg', frame='icrs')

-	# Coordinates around test-object (2019yvr):
-	coo_centre = SkyCoord(
-		ra=256.727512,
-		dec=30.271482,
-		unit='deg',
-		frame='icrs'
-	)
+    results, simbad = catalogs.query_simbad(coo_centre)

-	results, simbad = catalogs.query_simbad(coo_centre)
+    assert isinstance(results, Table)
+    assert isinstance(simbad, SkyCoord)
+    assert len(results) > 0
+    results.pprint_all(50)

-	assert isinstance(results, Table)
-	assert isinstance(simbad, SkyCoord)
-	assert len(results) > 0
-	results.pprint_all(50)

-#--------------------------------------------------------------------------------------------------
+
+# --------------------------------------------------------------------------------------------------
 def test_query_skymapper():
+    # Coordinates around test-object (2021aess):
+    coo_centre = SkyCoord(ra=53.4505, dec=-19.495725, unit='deg', frame='icrs')

-	# Coordinates around test-object (2021aess):
-	coo_centre = SkyCoord(
-		ra=53.4505,
-		dec=-19.495725,
-		unit='deg',
-		frame='icrs'
-	)
+    results, skymapper = catalogs.query_skymapper(coo_centre)

-	results, skymapper = catalogs.query_skymapper(coo_centre)
+    assert isinstance(results, Table)
+    assert isinstance(skymapper, SkyCoord)
+    assert len(results) > 0
+    results.pprint_all(50)

-	assert isinstance(results, Table)
-	assert isinstance(skymapper, SkyCoord)
-	assert len(results) > 0
-	results.pprint_all(50) 
-#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- -@pytest.mark.parametrize('ra,dec', [ - [256.727512, 30.271482], # 2019yvr - [58.59512, -19.18172], # 2009D -]) +@pytest.mark.parametrize('ra,dec', [[256.727512, 30.271482], # 2019yvr + [58.59512, -19.18172], # 2009D + ]) def test_download_catalog(SETUP_CONFIG, ra, dec): - - # Check if CasJobs have been configured, and skip the entire test if it isn't. - # This has to be done like this, to avoid problems when config.ini doesn't exist. - try: - catalogs.configure_casjobs() - except catalogs.CasjobsError: - pytest.skip("CasJobs not configured") - - # Coordinates around test-object (2019yvr): - coo_centre = SkyCoord( - ra=ra, - dec=dec, - unit='deg', - frame='icrs' - ) - - tab = catalogs.query_all(coo_centre) - print(tab) - - assert isinstance(tab, Table), "Should return a Table" - results = catalogs.convert_table_to_dict(tab) - - assert isinstance(results, list), "Should return a list" - for obj in results: - assert isinstance(obj, dict), "Each element should be a dict" - - # Check columns: - assert 'starid' in obj and obj['starid'] > 0 - assert 'ra' in obj and np.isfinite(obj['ra']) - assert 'decl' in obj and np.isfinite(obj['decl']) - assert 'pm_ra' in obj - assert 'pm_dec' in obj - assert 'gaia_mag' in obj - assert 'gaia_bp_mag' in obj - assert 'gaia_rp_mag' in obj - assert 'gaia_variability' in obj - assert 'B_mag' in obj - assert 'V_mag' in obj - assert 'u_mag' in obj - assert 'g_mag' in obj - assert 'r_mag' in obj - assert 'i_mag' in obj - assert 'z_mag' in obj - assert 'H_mag' in obj - assert 'J_mag' in obj - assert 'K_mag' in obj - - # All values should be finite number or None: - for key, val in obj.items(): - if key not in ('starid', 'gaia_variability'): - assert val is None or np.isfinite(val), f"{key} is not a valid value: {val}" - - # TODO: Manually check a target from this position if the merge is correct. - #assert False - -#-------------------------------------------------------------------------------------------------- + # Check if CasJobs have been configured, and skip the entire test if it isn't. + # This has to be done like this, to avoid problems when config.ini doesn't exist. 
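+    # The try/skip construction below is the usual pytest idiom for tests that
+    # depend on optional external credentials: attempt the configuration and
+    # convert the known failure into a skipped test instead of an error. A
+    # generic, self-contained sketch of the idiom (ServiceError and
+    # configure_service are placeholders, not part of flows):
+    #
+    #     import pytest
+    #
+    #     class ServiceError(RuntimeError):
+    #         pass
+    #
+    #     def configure_service():
+    #         raise ServiceError("no credentials found")
+    #
+    #     def test_needs_service():
+    #         try:
+    #             configure_service()
+    #         except ServiceError:
+    #             pytest.skip("service not configured")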
+ try: + catalogs.configure_casjobs() + except catalogs.CasjobsError: + pytest.skip("CasJobs not configured") + + # Coordinates around test-object (2019yvr): + coo_centre = SkyCoord(ra=ra, dec=dec, unit='deg', frame='icrs') + + tab = catalogs.query_all(coo_centre) + print(tab) + + assert isinstance(tab, Table), "Should return a Table" + results = catalogs.convert_table_to_dict(tab) + + assert isinstance(results, list), "Should return a list" + for obj in results: + assert isinstance(obj, dict), "Each element should be a dict" + + # Check columns: + assert 'starid' in obj and obj['starid'] > 0 + assert 'ra' in obj and np.isfinite(obj['ra']) + assert 'decl' in obj and np.isfinite(obj['decl']) + assert 'pm_ra' in obj + assert 'pm_dec' in obj + assert 'gaia_mag' in obj + assert 'gaia_bp_mag' in obj + assert 'gaia_rp_mag' in obj + assert 'gaia_variability' in obj + assert 'B_mag' in obj + assert 'V_mag' in obj + assert 'u_mag' in obj + assert 'g_mag' in obj + assert 'r_mag' in obj + assert 'i_mag' in obj + assert 'z_mag' in obj + assert 'H_mag' in obj + assert 'J_mag' in obj + assert 'K_mag' in obj + + # All values should be finite number or None: + for key, val in obj.items(): + if key not in ('starid', 'gaia_variability'): + assert val is None or np.isfinite(val), f"{key} is not a valid value: {val}" + + +# TODO: Manually check a target from this position if the merge is correct. +# assert False + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 3dd3927..2d11c21 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -12,55 +12,52 @@ from astropy.wcs import WCS from astropy.coordinates import SkyCoord import os.path -import conftest # noqa: F401 +import conftest # noqa: F401 from flows.api import get_filters from flows.load_image import load_image -#-------------------------------------------------------------------------------------------------- -@pytest.mark.parametrize('fpath,siteid', [ - ['SN2020aatc_K_20201213_495s.fits.gz', 13], - ['ADP.2021-10-15T11_40_06.553.fits.gz', 2], - #['TJO2459406.56826_V_imc.fits.gz', 22], - #['lsc1m009-fa04-20210704-0044-e91_v1.fits.gz', 4], - #['SN2021rcp_59409.931159242_B.fits.gz', 22], - #['SN2021rhu_59465.86130221_B.fits.gz', 22], - #['20200613_SN2020lao_u_stacked_meandiff.fits.gz', 1], - #['2021aess_20220104_K.fits.gz', 5], - #['2021aess_B01_20220207v1.fits.gz', 5], -]) -def test_load_image(fpath, siteid): - # Get list of all available filters: - all_filters = set(get_filters().keys()) - - # The test input directory containing the test-images: - INPUT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'input') - - # Target coordinates, only used for HAWKI image: - target_coord = SkyCoord( - ra=347.6230189, - dec=7.5888196, - unit='deg', - frame='icrs') - # Load the image from the test-set: - img = load_image(os.path.join(INPUT_DIR, fpath), target_coord=target_coord) - - # Check the attributes of the image object: - assert isinstance(img.image, np.ndarray) - assert img.image.dtype in ('float32', 'float64') - assert isinstance(img.mask, np.ndarray) - assert img.mask.dtype == 'bool' - assert isinstance(img.clean, np.ma.MaskedArray) - assert img.clean.dtype == img.image.dtype - assert isinstance(img.obstime, Time) - assert isinstance(img.exptime, float) - assert img.exptime > 0 - assert isinstance(img.wcs, WCS) - assert 
isinstance(img.site, dict) - assert img.site['siteid'] == siteid - assert isinstance(img.photfilter, str) - assert img.photfilter in all_filters - -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- +@pytest.mark.parametrize('fpath,siteid', + [['SN2020aatc_K_20201213_495s.fits.gz', 13], ['ADP.2021-10-15T11_40_06.553.fits.gz', 2], + # ['TJO2459406.56826_V_imc.fits.gz', 22], + # ['lsc1m009-fa04-20210704-0044-e91_v1.fits.gz', 4], + # ['SN2021rcp_59409.931159242_B.fits.gz', 22], + # ['SN2021rhu_59465.86130221_B.fits.gz', 22], + # ['20200613_SN2020lao_u_stacked_meandiff.fits.gz', 1], + # ['2021aess_20220104_K.fits.gz', 5], + # ['2021aess_B01_20220207v1.fits.gz', 5], + ]) +def test_load_image(fpath, siteid): + # Get list of all available filters: + all_filters = set(get_filters().keys()) + + # The test input directory containing the test-images: + INPUT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'input') + + # Target coordinates, only used for HAWKI image: + target_coord = SkyCoord(ra=347.6230189, dec=7.5888196, unit='deg', frame='icrs') + + # Load the image from the test-set: + img = load_image(os.path.join(INPUT_DIR, fpath), target_coord=target_coord) + + # Check the attributes of the image object: + assert isinstance(img.image, np.ndarray) + assert img.image.dtype in ('float32', 'float64') + assert isinstance(img.mask, np.ndarray) + assert img.mask.dtype == 'bool' + assert isinstance(img.clean, np.ma.MaskedArray) + assert img.clean.dtype == img.image.dtype + assert isinstance(img.obstime, Time) + assert isinstance(img.exptime, float) + assert img.exptime > 0 + assert isinstance(img.wcs, WCS) + assert isinstance(img.site, dict) + assert img.site['siteid'] == siteid + assert isinstance(img.photfilter, str) + assert img.photfilter in all_filters + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) diff --git a/tests/test_photometry.py b/tests/test_photometry.py index 15b379f..3bab9cb 100644 --- a/tests/test_photometry.py +++ b/tests/test_photometry.py @@ -7,14 +7,17 @@ """ import pytest -import conftest # noqa: F401 -from flows import photometry # noqa: F401 +import conftest # noqa: F401 +from flows import photometry # noqa: F401 -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def test_import_photometry(): - pass - #assert photometry + pass + + +# assert photometry -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) diff --git a/tests/test_tns.py b/tests/test_tns.py index 98103c5..82c8d62 100644 --- a/tests/test_tns.py +++ b/tests/test_tns.py @@ -13,80 +13,72 @@ from conftest import capture_cli from flows import tns -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") + +# 
-------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_tns_search(SETUP_CONFIG): + coo_centre = SkyCoord(ra=191.283890127, dec=-0.45909033652, unit='deg', frame='icrs') + res = tns.tns_search(coo_centre) - coo_centre = SkyCoord( - ra=191.283890127, - dec=-0.45909033652, - unit='deg', - frame='icrs' - ) - res = tns.tns_search(coo_centre) + print(res) + assert res[0]['objname'] == '2019yvr' + assert res[0]['prefix'] == 'SN' - print(res) - assert res[0]['objname'] == '2019yvr' - assert res[0]['prefix'] == 'SN' -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_tns_get_obj(SETUP_CONFIG): + res = tns.tns_get_obj('2019yvr') - res = tns.tns_get_obj('2019yvr') + print(res) + assert res['objname'] == '2019yvr' + assert res['name_prefix'] == 'SN' - print(res) - assert res['objname'] == '2019yvr' - assert res['name_prefix'] == 'SN' -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_tns_get_obj_noexist(SETUP_CONFIG): - res = tns.tns_get_obj('1892doesnotexist') - print(res) - assert res is None - -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") -@pytest.mark.parametrize('date_begin,date_end', [ - ('2019-01-01', '2019-02-01'), - (datetime.date(2019, 1, 1), datetime.date(2019, 2, 1)), - (datetime.datetime(2019, 1, 1, 12, 0), datetime.datetime(2019, 2, 1, 12, 0)) -]) + res = tns.tns_get_obj('1892doesnotexist') + print(res) + assert res is None + + +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") +@pytest.mark.parametrize('date_begin,date_end', + [('2019-01-01', '2019-02-01'), (datetime.date(2019, 1, 1), datetime.date(2019, 2, 1)), + (datetime.datetime(2019, 1, 1, 12, 0), datetime.datetime(2019, 2, 1, 12, 0))]) def test_tns_getnames(SETUP_CONFIG, date_begin, date_end): + names = tns.tns_getnames(date_begin=date_begin, date_end=date_end, zmin=0, zmax=0.105, objtype=3) + + print(names) + assert isinstance(names, list), "Should return a list" + for n in names: + assert isinstance(n, str), "Each element should be a string" + assert n.startswith('SN'), "All names should begin with 'SN'" + assert 'SN2019A' in names, "SN2019A should be in the list" - names = tns.tns_getnames( - date_begin=date_begin, - date_end=date_end, - zmin=0, - zmax=0.105, - objtype=3 
- ) - - print(names) - assert isinstance(names, list), "Should return a list" - for n in names: - assert isinstance(n, str), "Each element should be a string" - assert n.startswith('SN'), "All names should begin with 'SN'" - assert 'SN2019A' in names, "SN2019A should be in the list" - -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def test_tns_getnames_wronginput(SETUP_CONFIG): - # Wrong dates should result in ValueError: - with pytest.raises(ValueError): - tns.tns_getnames( - date_begin=datetime.date(2019, 1, 1), - date_end=datetime.date(2017, 1, 1) - ) - -#-------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="Disabled on GitHub Actions to avoid too many requests HTTP error") + # Wrong dates should result in ValueError: + with pytest.raises(ValueError): + tns.tns_getnames(date_begin=datetime.date(2019, 1, 1), date_end=datetime.date(2017, 1, 1)) + + +# -------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.environ.get('CI') == 'true', + reason="Disabled on GitHub Actions to avoid too many requests HTTP error") def test_run_querytns(SETUP_CONFIG): + # Run the command line interface: + out, err, exitcode = capture_cli('run_querytns.py') + assert exitcode == 0 - # Run the command line interface: - out, err, exitcode = capture_cli('run_querytns.py') - assert exitcode == 0 -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) diff --git a/tests/test_ztf.py b/tests/test_ztf.py index fa0c550..0c1621f 100644 --- a/tests/test_ztf.py +++ b/tests/test_ztf.py @@ -16,71 +16,60 @@ from conftest import capture_cli from flows import ztf -#-------------------------------------------------------------------------------------------------- + +# -------------------------------------------------------------------------------------------------- def test_ztf_id(): + coo_centre = SkyCoord(ra=191.283890127, dec=-0.45909033652, unit='deg', frame='icrs') + ztfid = ztf.query_ztf_id(coo_centre) + assert ztfid == 'ZTF20aabqkxs' + + # With the correct discovery date we should get the same result: + dd = Time('2019-12-27 12:30:14', format='iso', scale='utc') + ztfid = ztf.query_ztf_id(coo_centre, discovery_date=dd) + assert ztfid == 'ZTF20aabqkxs' + + # With a wrong discovery date, we should not get a ZTF id: + dd = Time('2021-12-24 18:00:00', format='iso', scale='utc') + ztfid = ztf.query_ztf_id(coo_centre, discovery_date=dd) + assert ztfid is None + + coo_centre = SkyCoord(ra=181.6874198, dec=67.1649528, unit='deg', frame='icrs') + ztfid = ztf.query_ztf_id(coo_centre) + assert ztfid == 'ZTF21aatyplr' + - coo_centre = SkyCoord( - ra=191.283890127, - dec=-0.45909033652, - unit='deg', - frame='icrs' - ) - ztfid = ztf.query_ztf_id(coo_centre) - assert ztfid == 'ZTF20aabqkxs' - - # With the correct discovery date we should get the same result: - dd = Time('2019-12-27 12:30:14', format='iso', scale='utc') - ztfid = ztf.query_ztf_id(coo_centre, discovery_date=dd) - assert ztfid == 'ZTF20aabqkxs' - - # With a wrong discovery date, we should not get a ZTF id: - dd = 
Time('2021-12-24 18:00:00', format='iso', scale='utc') - ztfid = ztf.query_ztf_id(coo_centre, discovery_date=dd) - assert ztfid is None - - coo_centre = SkyCoord( - ra=181.6874198, - dec=67.1649528, - unit='deg', - frame='icrs' - ) - ztfid = ztf.query_ztf_id(coo_centre) - assert ztfid == 'ZTF21aatyplr' - -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- @pytest.mark.parametrize('targetid', [2, 865]) def test_ztf_photometry(SETUP_CONFIG, targetid): + tab = ztf.download_ztf_photometry(targetid) + print(tab) - tab = ztf.download_ztf_photometry(targetid) - print(tab) + assert isinstance(tab, Table) + assert 'time' in tab.colnames + assert 'photfilter' in tab.colnames + assert 'mag' in tab.colnames + assert 'mag_err' in tab.colnames + assert np.all(np.isfinite(tab['time'])) + assert np.all(np.isfinite(tab['mag'])) + assert np.all(np.isfinite(tab['mag_err'])) - assert isinstance(tab, Table) - assert 'time' in tab.colnames - assert 'photfilter' in tab.colnames - assert 'mag' in tab.colnames - assert 'mag_err' in tab.colnames - assert np.all(np.isfinite(tab['time'])) - assert np.all(np.isfinite(tab['mag'])) - assert np.all(np.isfinite(tab['mag_err'])) -#-------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- @pytest.mark.parametrize('targetid', [2, 865]) def test_run_download_ztf(targetid): - with tempfile.TemporaryDirectory() as tmpdir: - # Nothing exists before running: - assert len(os.listdir(tmpdir)) == 0 - - # Run the command line interface: - out, err, exitcode = capture_cli('run_download_ztf.py', [ - f'--target={targetid:d}', - '-o', tmpdir - ]) - assert exitcode == 0 - - # The output directory should now have two files: - print(os.listdir(tmpdir)) - assert len(os.listdir(tmpdir)) == 2 - -#-------------------------------------------------------------------------------------------------- + with tempfile.TemporaryDirectory() as tmpdir: + # Nothing exists before running: + assert len(os.listdir(tmpdir)) == 0 + + # Run the command line interface: + out, err, exitcode = capture_cli('run_download_ztf.py', [f'--target={targetid:d}', '-o', tmpdir]) + assert exitcode == 0 + + # The output directory should now have two files: + print(os.listdir(tmpdir)) + assert len(os.listdir(tmpdir)) == 2 + + +# -------------------------------------------------------------------------------------------------- if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) From 21a3bee4271164bda25a04e9a42b4160cfc988a0 Mon Sep 17 00:00:00 2001 From: Emir Date: Mon, 14 Feb 2022 22:15:39 +0100 Subject: [PATCH 2/3] change editorconfig to space --- .editorconfig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.editorconfig b/.editorconfig index 2d1324b..5349964 100644 --- a/.editorconfig +++ b/.editorconfig @@ -6,7 +6,7 @@ charset = utf-8 # Python source files [*.py] -indent_style = tab +indent_style = space indent_size = 4 trim_trailing_whitespace = true insert_final_newline = true From 683801e9200c04829eba87e77ba519e565a73236 Mon Sep 17 00:00:00 2001 From: Emir Karamehmetoglu Date: Thu, 3 Mar 2022 13:12:59 +0100 Subject: [PATCH 3/3] fix config.ini --- setup.cfg | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 8c93ff0..094903d 100644 --- 
a/setup.cfg
+++ b/setup.cfg
@@ -1,7 +1,9 @@
 [flake8]
 exclude = .git,__pycache__,notes
-max-line-length = 120 # To be compliant with black
-extend-ignore = E203 # To be compliant with black
+# To be compliant with black
+max-line-length = 120
+# To be compliant with black
+extend-ignore = E203

 # Enable flake8-logging-format:
 enable-extensions = G
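
A closing note on this last hunk: flake8 reads setup.cfg with an INI parser that does not reliably strip inline "#" comments, so a value like "max-line-length = 120 # To be compliant with black" can come back as an unparseable string and break flake8 at startup; moving each comment onto its own line, as the hunk does, sidesteps that (this is presumably the "fix config.ini" the commit message refers to). For completeness, the matching black settings would normally live in pyproject.toml; a minimal sketch, assuming black's defaults otherwise:

    [tool.black]
    line-length = 120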