#%%
import inspect
import os
from pathlib import Path
from multiprocessing import Pool
import numpy as np
import astropy.units as u
from astropy.coordinates import Longitude, Latitude, Angle, SkyCoord
from astropy.table import Table
from astropy.time import Time
from astropy.wcs import WCS
from astroquery.hips2fits import hips2fits
from astroquery.mocserver import MOCServer
from matplotlib import cm
from ezphot.imageobjects import ScienceImage
#%%
# Large cutouts can take a long time to render server-side; raise the
# astroquery hips2fits timeout (presumably seconds) so they are not aborted.
_HIPS2FITS_TIMEOUT = 300
hips2fits.timeout = _HIPS2FITS_TIMEOUT
class HIPS2FITS:
    """
    Query HiPS surveys through the hips2fits service (via astroquery).

    A survey is selected with a catalog key such as ``'PanSTARRS/PS1/g'``
    (see ``catalog_ids``). Rendering options are kept in ``config`` and are
    forwarded to every request made by ``_query``.
    """

    # Catalog key ("SURVEY/RELEASE/FILTER") -> HiPS identifier. Insertion order
    # is the order used when listing catalogs and when checking all coverages.
    _CATALOG_IDS = {
        # SkyMapper DR4
        'SkyMapper/SMSS4/g': 'CDS/P/Skymapper/DR4/g',
        'SkyMapper/SMSS4/r': 'CDS/P/Skymapper/DR4/r',
        'SkyMapper/SMSS4/i': 'CDS/P/Skymapper/DR4/i',
        # SkyMapper DR1
        'SkyMapper/SMSS1/u': 'CDS/P/Skymapper-U',
        'SkyMapper/SMSS1/g': 'CDS/P/Skymapper-G',
        'SkyMapper/SMSS1/v': 'CDS/P/Skymapper-V',
        'SkyMapper/SMSS1/r': 'CDS/P/Skymapper-R',
        'SkyMapper/SMSS1/i': 'CDS/P/Skymapper-I',
        'SkyMapper/SMSS1/z': 'CDS/P/Skymapper-Z',
        # Pan-STARRS DR1
        'PanSTARRS/PS1/g': "CDS/P/PanSTARRS/DR1/g",
        'PanSTARRS/PS1/r': "CDS/P/PanSTARRS/DR1/r",
        'PanSTARRS/PS1/i': "CDS/P/PanSTARRS/DR1/i",
        'PanSTARRS/PS1/z': "CDS/P/PanSTARRS/DR1/z",
        'PanSTARRS/PS1/y': "CDS/P/PanSTARRS/DR1/y",
        # SDSS DR9
        'SDSS/SDSS9/u': "CDS/P/SDSS9/u",
        'SDSS/SDSS9/g': "CDS/P/SDSS9/g",
        'SDSS/SDSS9/r': "CDS/P/SDSS9/r",
        'SDSS/SDSS9/i': "CDS/P/SDSS9/i",
        'SDSS/SDSS9/z': "CDS/P/SDSS9/z",
        # DESI Legacy Imaging Survey
        'DESI/DESI/g': "CDS/P/DESI-Legacy-Surveys/DR10/g",
        'DESI/DESI/r': "CDS/P/DESI-Legacy-Surveys/DR10/r",
        'DESI/DESI/i': "CDS/P/DESI-Legacy-Surveys/DR10/i",
        'DESI/DESI/z': "CDS/P/DESI-Legacy-Surveys/DR10/z",
        # DSS (the NIR plate, "CDS/P/DSS2/NIR", is intentionally not offered)
        'DSS/DSS2/b': "CDS/P/DSS2/blue",
        'DSS/DSS2/r': "CDS/P/DSS2/red",
        # ZTF
        'ZTF/ZTF7/g': "CDS/P/ZTF/DR7/g",
        'ZTF/ZTF7/r': "CDS/P/ZTF/DR7/r",
        'ZTF/ZTF7/i': "CDS/P/ZTF/DR7/i",
        # DECam (DECaLS)
        'DECALS/DEC5/g': "CDS/P/DECaLS/DR5/g",
        'DECALS/DEC5/r': "CDS/P/DECaLS/DR5/r",
        # DES
        'DES/DES2/g': "CDS/P/DES-DR2/g",
        'DES/DES2/r': "CDS/P/DES-DR2/r",
        'DES/DES2/i': "CDS/P/DES-DR2/i",
        'DES/DES2/z': "CDS/P/DES-DR2/z",
        'DES/DES2/Y': "CDS/P/DES-DR2/Y",
    }

    class Configuration:
        """
        Rendering options forwarded to hips2fits queries.

        Plain, mutable attributes. Defined at class level (not inside a
        function) so instances can be pickled, which ``multiprocessing``
        needs when a querier is shipped to worker processes.
        """

        def __init__(self):
            self.projection = 'TAN'
            self.coordsys = 'icrs'
            self.format = 'fits'
            self.stretch = 'linear'
            self.cmap = 'Greys_r'  # Matplotlib colormap *name*
            self.min_cut = 0.5
            self.max_cut = 99.5

        def __repr__(self):
            return (f"========HIPS2FITS Configuration========\n"
                    f" projection = {self.projection}\n"
                    f" coordsys = {self.coordsys}\n"
                    f" format = {self.format}\n"
                    f" stretch = {self.stretch}\n"
                    f" cmap = {self.cmap}\n"
                    f" min_cut = {self.min_cut}\n"
                    f" max_cut = {self.max_cut}\n"
                    f"====================================")

    def __init__(self, catalog_key: str = None):
        """
        Initialize the querier, optionally selecting a catalog.

        :param catalog_key: Key of ``catalog_ids`` (e.g. ``'PanSTARRS/PS1/g'``),
            or None to select a catalog later.
        :raises ValueError: If ``catalog_key`` is not a recognized key.
        """
        if catalog_key is not None and catalog_key not in self.catalog_ids:
            raise ValueError(f"Catalog Key '{catalog_key}' is not recognized. Available keys: {list(self.catalog_ids.keys())}")
        # Always define the attribute so callers can compare it against None.
        self.current_catalog_key = catalog_key

    @property
    def config(self):
        """
        Rendering configuration of this instance.

        Created on first access and then cached, so assignments such as
        ``obj.config.cmap = 'viridis'`` persist for subsequent queries.
        """
        config = getattr(self, '_config', None)
        if config is None:
            config = self.Configuration()
            self._config = config
        return config

    @property
    def catalog_ids(self):
        """Return a fresh ``{catalog_key: hips_id}`` dict (safe to mutate)."""
        return dict(self._CATALOG_IDS)

    def _change_catalog(self, catalog_key):
        """
        Select ``catalog_key`` as the current catalog and print ``repr(self)``.

        :raises ValueError: If ``catalog_key`` is not a recognized key.
        """
        if catalog_key not in self.catalog_ids:
            raise ValueError(f"Catalog Key '{catalog_key}' is not recognized.")
        self.current_catalog_key = catalog_key
        print(self.__repr__())

    def _show_available_catalogs(self):
        """Print the current and all available catalogs; return the keys."""
        catalog_ids = self.catalog_ids
        print("Current catalog: ", self.current_catalog_key)
        print("Available catalogs\n==================")
        for catalog_name, catalog_id in catalog_ids.items():
            print(f"{catalog_name}: {catalog_id}")
        return list(catalog_ids.keys())

    def _check_catalog_coverages(self,
                                 ra: float,
                                 dec: float,
                                 radius_deg: float = 0.1,
                                 verbose: bool = True,
                                 search_all: bool = False) -> dict:
        """
        Ask the CDS MOC server whether catalogs overlap a circle on the sky.

        :param ra: Right ascension of the circle centre in degrees.
        :param dec: Declination of the circle centre in degrees.
        :param radius_deg: Circle radius in degrees.
        :param verbose: Print one line per checked catalog (and any error).
        :param search_all: Check every catalog instead of only the current one.
        :return: ``{catalog_key: bool}``. A failed lookup (network error,
            unknown/unset key, ...) is reported as ``False`` rather than raised.
        """
        from regions import CircleSkyRegion
        coord = SkyCoord(ra=ra * u.deg, dec=dec * u.deg)
        region = CircleSkyRegion(center=coord, radius=radius_deg * u.deg)
        catalog_ids = self.catalog_ids
        catalog_keys = list(catalog_ids) if search_all else [self.current_catalog_key]
        results = {}
        for catalog_key in catalog_keys:
            try:
                hips_id = catalog_ids[catalog_key]
                query = MOCServer.query_region(
                    region=region,
                    criteria=f"ID={hips_id}",
                    intersect="overlaps",
                    max_rec=1
                )
                has_coverage = len(query) > 0
            except Exception as e:
                # Best effort: one failing catalog must not abort the others.
                has_coverage = False
                if verbose:
                    print(f"[ERROR] Failed to check {catalog_key}: {e}")
            results[catalog_key] = has_coverage
            if verbose:
                print(f"[{catalog_key}] {'✓' if has_coverage else '✗'}")
        return results

    @staticmethod
    def _get_cmap(name):
        """
        Resolve a Matplotlib colormap by name.

        ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9, so prefer
        the ``matplotlib.colormaps`` registry and only fall back to the old
        function on versions that lack the registry.
        """
        import matplotlib
        registry = getattr(matplotlib, 'colormaps', None)
        if registry is not None:
            return registry[name]
        return cm.get_cmap(name)

    def _query(self,
               wcs: "WCS" = None,  # If wcs inputted, overrides ra, dec, fov, rotation_angle
               width: int = 2000,
               height: int = 2000,
               ra: float = 0.0,
               dec: float = 0.0,
               fov: float = 5.0,
               rotation_angle: float = 0.0,
               save_path: str = None,
               verbose: bool = False,
               ):
        """
        Fetch one cutout from the current catalog and write it to disk.

        :param wcs: Target WCS; when given, ``width``, ``height``, ``ra``,
            ``dec``, ``fov`` and ``rotation_angle`` are ignored.
        :param width: Image width in pixels.
        :param height: Image height in pixels.
        :param ra: Centre right ascension in degrees.
        :param dec: Centre declination in degrees.
        :param fov: Field of view in degrees.
        :param rotation_angle: Rotation angle in degrees.
        :param save_path: Output file; defaults to
            ``./hips2fits_<catalog key with '/' -> '_'>_<ra>_<dec>.fits``.
        :param verbose: Print progress and forward ``verbose`` to astroquery.
        :return: ``save_path`` (as passed in, or the generated default).
        :raises KeyError: If no valid catalog is selected.
        """
        hips_id = self.catalog_ids[self.current_catalog_key]
        if save_path is None:
            # Catalog keys contain '/', which os.path.join would treat as
            # (usually non-existent) subdirectories.
            safe_key = str(self.current_catalog_key).replace('/', '_')
            save_path = os.path.join(os.getcwd(), f"hips2fits_{safe_key}_{ra}_{dec}.fits")
            if verbose:
                print(f"Default save path: {save_path}")
        config = self.config
        render_kwargs = dict(
            format=config.format,
            min_cut=config.min_cut,
            max_cut=config.max_cut,
            stretch=config.stretch,
            cmap=self._get_cmap(config.cmap),
            verbose=verbose,
        )
        if wcs is not None:
            # If WCS is provided, use it to query the image
            result = hips2fits.query_with_wcs(hips=hips_id, wcs=wcs, **render_kwargs)
        else:
            result = hips2fits.query(
                hips=hips_id,
                width=width,
                height=height,
                projection=config.projection,
                fov=Angle(fov * u.deg),
                ra=Longitude(ra * u.deg),
                dec=Latitude(dec * u.deg),
                coordsys=config.coordsys,
                rotation_angle=Angle(rotation_angle * u.deg),
                **render_kwargs,
            )
        result[0].writeto(save_path, overwrite=True)
        if verbose:
            print(f"Saved: {save_path}")
        return save_path
class ImageQuerier(HIPS2FITS):
    """
    A class to handle image queries using HiPS2FITS.

    Inherits from HIPS2FITS to utilize its methods and properties and adds a
    public API: tiled/stacked ``query``, ``check_coverage``,
    ``change_catalog`` and ``show_available_catalogs``.
    """

    def __init__(self, catalog_key: str = None):
        """
        Initializes the ImageQuerier with the specified catalog key.

        :param catalog_key: Key for the catalog to query, or None to choose
            one later with ``change_catalog``.
        """
        super().__init__(catalog_key=catalog_key)
        # The base initializer may only assign this when a key is given;
        # define it unconditionally so __repr__/query can compare it to None.
        self.current_catalog_key = catalog_key

    def __repr__(self):
        return f"ImageQuerier(catalog={self.current_catalog_key})\n{self.config}\n For help, use 'help(self)' or `self.help()`."

    def help(self):
        """Print the public methods of this class with their signatures."""
        # Get all public methods from the class, excluding `help`
        methods = [
            (name, obj)
            for name, obj in inspect.getmembers(self.__class__, inspect.isfunction)
            if not name.startswith("_") and name != "help"
        ]
        lines = []
        for name, func in methods:
            sig = inspect.signature(func)
            params = [str(p) for p in sig.parameters.values() if p.name != "self"]
            sig_str = f"({', '.join(params)})" if params else "()"
            lines.append(f"- {name}{sig_str}")
        help_text = ""
        print(f"Help for {self.__class__.__name__}\n{help_text}\n\nPublic methods:\n" + "\n".join(lines))

    def query(self,
              width: int,
              height: int,
              ra: float,
              dec: float,
              pixelscale: float,
              telinfo: Table,
              save_path: str = None,
              objname: str = None,
              rotation_angle: float = 0.0,
              verbose: bool = True,
              n_processes: int = 4):
        """
        Download a (possibly tiled) HiPS cutout and stack it into one image.

        The area is split into tiles by ``_split_query_regions``, the tiles
        are fetched in parallel with ``multiprocessing.Pool``, and the tiles
        are combined with SWarp onto a ``width`` x ``height`` grid centred on
        (``ra``, ``dec``). Temporary tile images are removed afterwards.

        Parameters
        ----------
        width, height : int
            Output image size in pixels.
        ra, dec : float
            Output image centre in degrees.
        pixelscale : float
            Output pixel scale in arcsec per pixel.
        telinfo : astropy.table.Table
            Telescope information forwarded to ``ScienceImage``.
        save_path : str, optional
            Output file. Tile files are derived from it; its parent directory
            becomes the stacked image's ``savedir``. Defaults to a file in
            the home directory.
        objname : str, optional
            Stored in the OBJNAME header keyword ('Unknown' if not given).
        rotation_angle : float
            Rotation angle in degrees passed to every tile request.
        verbose : bool
            Print progress (also forwarded to the coverage check and stacking).
        n_processes : int
            Number of worker processes used to download tiles.

        Returns
        -------
        The stacked image converted with ``to_referenceimage()``.

        Raises
        ------
        ValueError
            If no catalog is selected, or the catalog has no coverage at
            (``ra``, ``dec``).
        """
        if getattr(self, 'current_catalog_key', None) is None:
            raise ValueError("No catalog key provided. Please set the catalog key using the change_catalog method.")
        coverage = self.check_coverage(ra, dec, search_all=False, verbose=verbose)
        if not coverage.get(self.current_catalog_key, False):
            raise ValueError(f"No coverage found for {self.current_catalog_key} at {ra}, {dec}")
        observatory, telname, filter_ = self.current_catalog_key.split('/')
        # FoV along the largest dimension, arcsec -> degrees.
        fov = max(width, height) * pixelscale / 3600
        tile_params = self._split_query_regions(
            width=width,
            height=height,
            ra=ra,
            dec=dec,
            fov=fov,
            rotation_angle=rotation_angle,
            verbose=verbose
        )
        if save_path is None:
            output_path = os.path.join(Path.home(), f"hips2fits_{observatory}_{telname}_{ra}_{dec}.fits")
        else:
            output_path = str(save_path)
        if verbose:
            print(f"[QUERY] Dispatching {len(tile_params)} tiles with {n_processes} processes")
        tasks = [(param, output_path, verbose) for param in tile_params]
        with Pool(processes=n_processes) as pool:
            results = pool.map(self._query_tile_worker, tasks)

        from ezphot.methods import Stack
        self.stacking = Stack()
        target_imglist = [ScienceImage(result, telinfo=telinfo, load=True) for result in results]
        stack_instance, _ = self.stacking.stack_swarp(
            target_imglist=target_imglist,
            target_bkglist=None,
            target_errormaplist=None,
            target_outpath=output_path,
            errormap_outpath=None,
            combine_type='average',
            resample=True,
            resample_type='LANCZOS3',
            center_ra=ra,
            center_dec=dec,
            pixel_scale=pixelscale,
            x_size=width,
            y_size=height,
            scale=False,
            scale_type='min',
            zp_key='ZP_APER_1',
            convolve=False,
            seeing_key='SEEING',
            kernel='gaussian',
            save=True,
            verbose=verbose
        )
        stack_instance.load()
        stack_instance.remove()
        stack_instance = stack_instance.to_referenceimage()

        # Placeholder metadata: HiPS cutouts carry no real observing info.
        update_header_kwargs = dict(
            BINNING=1,
            TELNAME=telname,
            FILTER=filter_,
            TELESCOP=observatory,
            OBJNAME=objname if objname else 'Unknown',
            IMAGETYP='LIGHT',
            OBSDATE=Time('2001-01-01T00:00:00').isot,  # Placeholder date
            SEEING=2.0,
            UL5SKY_APER_2=21.0
        )
        stack_instance.header.update(**update_header_kwargs)
        # Rescale low-valued data by 1e3. nanmax ignores blank (NaN) pixels;
        # plain max would return NaN there and silently skip the rescale.
        if np.nanmax(stack_instance.data) < 1e3:
            stack_instance.data *= 1e3
        if save_path is not None:
            stack_instance.savedir = Path(save_path).parent
            if verbose:
                print(f'Save path: {stack_instance.savepath.savepath}')
        else:
            stack_instance.savedir = None
            if verbose:
                print(f'Default save path: {stack_instance.savepath.savepath}')
        stack_instance.write()
        # Tile images are intermediates; drop them once the stack is written.
        for target_img in target_imglist:
            target_img.remove()
        return stack_instance

    def check_coverage(self,
                       ra: float,
                       dec: float,
                       radius_deg: float = 1.0,
                       search_all: bool = False,
                       verbose: bool = True):
        """
        Check if HiPS catalogs have coverage at the given RA/Dec.

        Parameters
        ----------
        ra : float
            Right Ascension in degrees
        dec : float
            Declination in degrees
        radius_deg : float
            Search radius in degrees
        search_all : bool
            If True, check every available catalog instead of only the
            current one.
        verbose : bool
            If True, prints whether coverage exists

        Returns
        -------
        coverage_dict : dict
            Dictionary of catalog_key -> bool (True if covered).
        """
        return self._check_catalog_coverages(ra=ra, dec=dec, radius_deg=radius_deg, verbose=verbose, search_all=search_all)

    def change_catalog(self, catalog_key):
        """
        Change the current catalog to query.

        Parameters
        ----------
        catalog_key : str
            The catalog key to change to.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``catalog_key`` is not recognized.
        """
        self._change_catalog(catalog_key)

    def show_available_catalogs(self):
        """
        Display available catalogs.

        Returns
        -------
        list of str
            Keys of the available catalogs.
        """
        return self._show_available_catalogs()

    def _split_query_regions(self,
                             width: int,
                             height: int,
                             ra: float,
                             dec: float,
                             fov: float,
                             rotation_angle: float = 0.0,
                             max_pixels: int = 45000000,
                             margin_fraction: float = 0.1,
                             verbose: bool = True):
        """
        Split a large query into tiles with unique RA/Dec centers that cover the full region.

        Parameters
        ----------
        width, height : int
            Full image size in pixels.
        ra, dec : float
            Full image centre in degrees.
        fov : float
            Field of view in degrees along the largest image dimension
            (``query`` passes ``max(width, height) * pixelscale``).
        rotation_angle : float
            Copied unchanged into every tile.
        max_pixels : int
            Requests larger than this are split into an n x n tile grid.
        margin_fraction : float
            Fractional padding added to each tile's size and FoV so that
            neighbouring tiles overlap.
        verbose : bool
            Print the split layout.

        Returns
        -------
        list of dict
            One dict per tile with keys 'width', 'height', 'ra', 'dec',
            'fov', 'rotation_angle' and 'tile_id' ("row_col").
        """
        total_pixels = width * height
        if total_pixels <= max_pixels:
            return [{
                'width': int(np.ceil(width * (1 + margin_fraction))),
                'height': int(np.ceil(height * (1 + margin_fraction))),
                'ra': ra,
                'dec': dec,
                'fov': fov * (1 + margin_fraction),
                'rotation_angle': rotation_angle,
                'tile_id': '0_0'
            }]
        # Determine number of splits needed per axis
        n_splits = int(np.ceil(np.sqrt(total_pixels / max_pixels)))
        tile_width = width // n_splits
        tile_height = height // n_splits
        base_tile_fov = fov / n_splits  # Original FoV per tile
        # Apply margin to width/height and FoV
        tile_width_with_margin = int(np.ceil(tile_width * (1 + margin_fraction)))
        tile_height_with_margin = int(np.ceil(tile_height * (1 + margin_fraction)))
        tile_fov_with_margin = base_tile_fov * (1 + margin_fraction)
        # Degrees per pixel. fov spans the largest dimension, so divide by it;
        # dividing by width alone overestimated the scale for tall images.
        pixscale_deg = fov / max(width, height)
        dec_rad = np.deg2rad(dec)
        if verbose:
            print(f"[SPLIT] {width}x{height} >>> {n_splits}x{n_splits} tiles "
                  f"({tile_width}x{tile_height} px), margin = {margin_fraction:.1%}")
        tile_params = []
        for i in range(n_splits):
            for j in range(n_splits):
                # Offset of the tile centre from the image centre (no margin).
                # RA offsets are stretched by 1/cos(dec) to stay uniform on sky.
                # NOTE(review): offsets ignore rotation_angle; tiles of a
                # rotated request are laid out on the unrotated grid.
                delta_ra_deg = ((j + 0.5) - n_splits / 2) * tile_width * pixscale_deg / np.cos(dec_rad)
                delta_dec_deg = ((n_splits / 2) - (i + 0.5)) * tile_height * pixscale_deg
                tile_params.append({
                    'width': tile_width_with_margin,
                    'height': tile_height_with_margin,
                    'ra': ra + delta_ra_deg,
                    'dec': dec + delta_dec_deg,
                    'fov': tile_fov_with_margin,
                    'rotation_angle': rotation_angle,
                    'tile_id': f"{i}_{j}"
                })
        return tile_params

    def _query_tile_worker(self, kwargs):
        """
        Download one tile; run in a worker process by ``query`` via ``Pool.map``.

        ``Pool.map`` pickles this bound method together with ``self``, so the
        instance must be picklable.

        :param kwargs: ``(tile_param, save_path, verbose)``, where
            ``tile_param`` is one dict produced by ``_split_query_regions``.
        :return: Path of the written tile: ``save_path`` with its suffix
            replaced by ``.<tile_id>.fits``.
        """
        tile_param, save_path, verbose = kwargs
        tile_path = Path(save_path).with_suffix(f".{tile_param['tile_id']}.fits")
        return self._query(
            wcs=None,
            width=tile_param['width'],
            height=tile_param['height'],
            ra=tile_param['ra'],
            dec=tile_param['dec'],
            fov=tile_param['fov'],
            rotation_angle=tile_param['rotation_angle'],
            verbose=verbose,
            save_path=tile_path
        )