# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Synthetic photometry utility functions."""
# STDLIB
import os
from shutil import copyfile
# THIRD-PARTY
import numpy as np
# ASTROPY
from astropy import units as u
from astropy.config import ConfigItem
from astropy.utils.data import download_file
# LOCAL
from synphot import exceptions, units
# Names exported by ``from <this module> import *``.
__all__ = ['overlap_status', 'validate_totalflux', 'validate_wavelengths',
           'generate_wavelengths', 'merge_wavelengths', 'download_data']
[docs]
def overlap_status(a, b):
    """Check overlap between two arrays.

    Only the extreme values (minimum and maximum) of each input are
    compared; the ordering of the elements does not matter.

    Parameters
    ----------
    a, b : array-like
        Arrays to check. Assumed to be in the same unit.
        Plain sequences such as lists or tuples are accepted as well
        as Numpy arrays.

    Returns
    -------
    result : {'full', 'partial', 'none'}
        * 'full' - ``a`` is within or same as ``b``
        * 'partial' - ``a`` partially overlaps with ``b``
        * 'none' - ``a`` does not overlap ``b``

    """
    # Use the Numpy functions instead of the ``.min()``/``.max()`` methods
    # so that any array-like input works, as the docstring promises.
    a1, a2 = np.min(a), np.max(a)
    b1, b2 = np.min(b), np.max(b)

    # ``a`` lies entirely inside (or coincides with) the range of ``b``.
    if a1 >= b1 and a2 <= b2:
        return 'full'

    # The two ranges are completely disjoint.
    if a2 < b1 or b2 < a1:
        return 'none'

    # Anything else (including ranges that merely touch) is partial.
    return 'partial'
[docs]
def validate_totalflux(totalflux):
    """Reject an integrated flux that cannot be used.

    The checks run in a fixed order: non-positive first, then NaN,
    then infinite. A negative infinity is therefore reported as
    being ``<= 0``.

    Parameters
    ----------
    totalflux : float
        Integrated flux.

    Raises
    ------
    synphot.exceptions.SynphotError
        Input is zero, negative, not a number, or infinite.

    """
    if totalflux <= 0.0:
        raise exceptions.SynphotError('Integrated flux is <= 0')

    if np.isnan(totalflux):
        raise exceptions.SynphotError('Integrated flux is NaN')

    if np.isinf(totalflux):
        raise exceptions.SynphotError('Integrated flux is infinite')
[docs]
def validate_wavelengths(wavelengths):
    """Check wavelengths for ``synphot`` compatibility.

    Wavelengths must satisfy these conditions:

        * valid unit type, if given
        * no zero or negative values
        * monotonic ascending or descending
        * no duplicate values

    Parameters
    ----------
    wavelengths : array-like or `~astropy.units.quantity.Quantity`
        Wavelength values. A scalar (including a 0-D array) is treated
        as a single-element array.

    Raises
    ------
    synphot.exceptions.SynphotError
        Wavelengths unit type is invalid.

    synphot.exceptions.DuplicateWavelength
        Wavelength array contains duplicate entries.

    synphot.exceptions.UnsortedWavelength
        Wavelength array is not monotonic.

    synphot.exceptions.ZeroWavelength
        Negative or zero wavelength occurs in wavelength array.

    """
    if isinstance(wavelengths, u.Quantity):
        units.validate_wave_unit(wavelengths.unit)
        wave = wavelengths.value
    else:
        wave = wavelengths

    # Promote scalars *and* 0-D arrays to 1-D. A 0-D array is not a
    # "scalar" according to np.isscalar, and np.sort below cannot
    # operate on 0-D input.
    wave = np.atleast_1d(np.asarray(wave))

    # Check for zeroes (and negative values)
    if np.any(wave <= 0):
        raise exceptions.ZeroWavelength(
            'Negative or zero wavelength occurs in wavelength array',
            rows=np.where(wave <= 0)[0])

    # Check for monotonicity
    sorted_wave = np.sort(wave)
    if not np.all(sorted_wave == wave):
        if np.all(sorted_wave[::-1] == wave):
            pass  # Monotonic descending is allowed
        else:
            raise exceptions.UnsortedWavelength(
                'Wavelength array is not monotonic',
                rows=np.where(sorted_wave != wave)[0])

    # Check for duplicate values.
    # This is done on the sorted copy, so reported rows index into the
    # ascending-sorted array.
    if wave.size > 1:
        dw = sorted_wave[1:] - sorted_wave[:-1]
        if np.any(dw == 0):
            raise exceptions.DuplicateWavelength(
                'Wavelength array contains duplicate entries',
                rows=np.where(dw == 0)[0])
[docs]
def generate_wavelengths(minwave=500, maxwave=26000, num=10000, delta=None,
                         log=True, wave_unit=u.AA):
    """Build a wavelength grid for sampling a spectrum.

    The grid is half-open, i.e., ``maxwave`` itself is excluded:

    .. math::

        minwave \\le \\lambda < maxwave

    Parameters
    ----------
    minwave, maxwave : float
        Lower and upper limits of the wavelengths.
        These must be values in linear space regardless of ``log``.

    num : int
        The number of wavelength values.
        This is only used when ``delta=None``.

    delta : float or `None`
        Delta between wavelength values.
        When ``log=True``, this is the spacing in log space.
        When given, it takes precedence over ``num``.

    log : bool
        If `True`, the wavelength values are evenly spaced in log scale.
        Otherwise, spacing is linear.

    wave_unit : str or `~astropy.units.Unit`
        Wavelength unit. Default is Angstrom.

    Returns
    -------
    waveset : `~astropy.units.quantity.Quantity`
        Generated wavelength set.

    waveset_str : str
        Info string associated with the result.

    """
    wave_unit = units.validate_unit(wave_unit)

    # A given step size overrides the requested number of points; the
    # info string then reports "Num: None".
    by_step = delta is not None
    if by_step:
        num = None

    waveset_str = 'Min: {0}, Max: {1}, Num: {2}, Delta: {3}, Log: {4}'.format(
        minwave, maxwave, num, delta, log)

    if log:
        # Limits are supplied in linear space; convert to exponents.
        lo_exp = np.log10(minwave)
        hi_exp = np.log10(maxwave)
        if by_step:
            values = 10 ** np.arange(lo_exp, hi_exp, delta)
        else:
            values = np.logspace(lo_exp, hi_exp, num, endpoint=False)
    elif by_step:
        values = np.arange(minwave, maxwave, delta)
    else:
        values = np.linspace(minwave, maxwave, num, endpoint=False)

    return values.astype(np.float64) * wave_unit, waveset_str
[docs]
def merge_wavelengths(waveset1, waveset2, threshold=1e-12):
    """Merge two wavelength sets into one sorted, de-duplicated set.

    The merge itself is done with :func:`numpy.union1d`. Because of
    floating-point noise, the union can contain neighboring values that
    are not exactly equal but differ by a tiny amount (e.g., 1e-14),
    which can cause problems downstream. Any value whose distance to
    its upper neighbor is not greater than ``threshold`` is dropped,
    i.e., the lower value of a too-close pair is removed.

    Parameters
    ----------
    waveset1, waveset2 : array-like or `None`
        Wavelength values, assumed to be in the same unit already.
        Also see :func:`~synphot.models.get_waveset`.

    threshold : float, optional
        Merged wavelength values are considered "too close together"
        when the difference is smaller than this number.
        The default is 1e-12.

    Returns
    -------
    out_wavelengths : array-like or `None`
        Merged wavelengths. `None` if both inputs are `None`.
        If only one input is defined, it is returned as-is.

    """
    # With a missing input there is nothing to merge; hand back the
    # other one unchanged (which is None when both are missing).
    if waveset1 is None:
        return waveset2
    if waveset2 is None:
        return waveset1

    merged = np.union1d(waveset1, waveset2)

    # True where a value is far enough from its upper neighbor.
    spacing = merged[1:] - merged[:-1]
    far_enough = spacing > threshold

    # Remove "too close together" duplicates. The last value has no
    # upper neighbor and is always kept.
    if not far_enough.all():
        merged = np.append(merged[:-1][far_enough], merged[-1])

    return merged
[docs]
def download_data(path_root, verbose=True, dry_run=False):
    """Fetch ``synphot`` data files into a root directory or the
    ``astropy`` cache.

    Every configuration item whose name ends with ``file`` and whose
    default value points to the expected host is considered.
    A file that already exists under ``path_root`` is skipped.
    A failed download is reported on screen and skipped; it does not
    abort the remaining downloads.

    .. warning::

        Downloading data to ``astropy`` cache only is not recommended
        if you plan to provide a custom ``synphot.cfg``.

    Parameters
    ----------
    path_root : str or `None`
        Root directory for data files. If `None`, download to the ``astropy``
        cache location instead of a specific directory.

    verbose : bool
        Print extra information to screen.

    dry_run : bool
        Go through the logic but skip the actual download.
        This would return a list of files that *would have been*
        downloaded without network calls. The sub-directories
        would still be created regardless.
        Use this option for debugging or testing.

    Returns
    -------
    file_list : list of str
        A list of downloaded files. Entries are destination paths when
        ``path_root`` is given; otherwise, they are cache paths (or the
        URLs, for a dry run).

    Raises
    ------
    OSError
        Problem with directory.

    """
    from synphot.config import conf  # Avoid potential circular import

    BASE_HOST = 'https://ssb.stsci.edu/trds/'
    cache_only = path_root is None

    if not cache_only:
        if os.path.exists(path_root):
            if not os.path.isdir(path_root):
                raise OSError('{} must be a directory'.format(path_root))
        else:
            os.makedirs(path_root, exist_ok=True)
            if verbose:  # pragma: no cover
                print('Created {}'.format(path_root))

        # The root replaces the host prefix of each URL below, so it
        # needs a trailing separator just like the host has.
        if not path_root.endswith(os.sep):
            path_root += os.sep

    file_list = []

    # See https://github.com/astropy/astropy/issues/8524
    for cfgitem in conf.__class__.__dict__.values():
        wanted = (isinstance(cfgitem, ConfigItem) and
                  cfgitem.name.endswith('file'))
        if not wanted:
            continue

        url = cfgitem.defaultvalue

        if not url.startswith(BASE_HOST):
            if verbose:  # pragma: no cover
                print('{} is not from {}, skipping download'.format(
                    url, BASE_HOST))
            continue

        if not cache_only:
            dst = url.replace(BASE_HOST, path_root).replace('/', os.sep)

            if os.path.exists(dst):
                if verbose:  # pragma: no cover
                    print('{} already exists, skipping download'.format(dst))
                continue

            # Create sub-directories, if needed. This happens even for
            # a dry run.
            os.makedirs(os.path.dirname(dst), exist_ok=True)

        if not dry_run:  # pragma: no cover
            try:
                src = download_file(url, cache=True)
                if not cache_only:
                    copyfile(src, dst)
            except Exception as exc:
                # Best effort: report the problem and move on.
                print('Download failed - {}'.format(str(exc)))
                continue

        # Decide what to report for this file.
        if not cache_only:
            saved = dst
        elif dry_run:
            saved = url
        else:  # pragma: no cover
            saved = src
        file_list.append(saved)

        if verbose:  # pragma: no cover
            print('{} downloaded to {}'.format(url, saved))

    return file_list