#%%
import inspect
from pathlib import Path
from tqdm import tqdm
import pandas as pd
from typing import Union, List
from astropy.time import Time
from astropy.table import Table
from concurrent.futures import ProcessPoolExecutor
import numpy as np
from ezphot.dataobjects import Catalog
from ezphot.imageobjects import ScienceImage
from ezphot.helper import Helper
#%%
[docs]
class CatalogSet:
"""
CatalogSet class for managing a set of catalogs.
This class provides methods
1. Search for catalogs in the given folder.
2. Select catalogs with given criteria.
3. Exclude catalogs with given criteria.
4. Add catalogs to the set.
5. Merge catalogs into single table.
6. Select sources from each Catalog instance.
"""
[docs]
def __init__(self, catalogs: list[Catalog] = None):
    """
    Set up a CatalogSet, optionally seeded with existing catalogs.

    Parameters
    ----------
    catalogs : list[Catalog], optional
        Catalog instances to manage. When omitted (None), the set starts
        out empty.
    """
    if catalogs is None:
        catalogs = []
    self.helper = Helper()
    self.catalogs = catalogs
    # Until a select/exclude call replaces it, the working selection is the
    # very same list object as the full set.
    self.target_catalogs = self.catalogs
    # Arguments of the most recent select/exclude call (shown by __repr__).
    self._last_filter = dict.fromkeys(
        (
            'file_key',
            'filter',
            'exptime',
            'objname',
            'obs_start',
            'obs_end',
            'seeing',
            'depth',
            'observatory',
            'telname',
        )
    )
    # Cache slot for the metadata DataFrame built by the ``df`` property.
    self._df = None
    self._target_df = None
    # Which kind of filtering ran last: "select" or "exclude".
    self._last_mode = "select"
def __repr__(self):
txt = f"CatalogSet[n_selected/n_catalogs= {len(self.target_catalogs)}/{len(self.catalogs)}] \n"
txt += 'SELECT FILTER ============\n'
for key, value in self._last_filter.items():
prefix = "!" if self._last_mode == "exclude" and value is not None else ""
txt += f"{prefix}{key:>11} = {value}\n"
return txt
def help(self):
    """Print every public method of this class together with its parameters."""
    cls = self.__class__
    entries = []
    for method_name, method in inspect.getmembers(cls, inspect.isfunction):
        # Private/dunder members and help() itself are not listed.
        if method_name.startswith("_") or method_name == "help":
            continue
        shown = [
            str(param)
            for param in inspect.signature(method).parameters.values()
            if param.name != "self"
        ]
        entries.append(f"- {method_name}({', '.join(shown)})")
    # No free-form description is printed; the slot is kept so the output
    # layout (blank lines included) stays unchanged.
    help_text = ""
    print(f"Help for {cls.__name__}\n{help_text}\n\nPublic methods:\n" + "\n".join(entries))
def merge_catalogs(self,
max_distance_arcsec=3.0,
ra_key='X_WORLD',
dec_key='Y_WORLD',
data_keys=['MAGSKY_AUTO', 'MAGERR_AUTO', 'ZP_AUTO', 'ZPERR_AUTO', 'MAGSKY_APER', 'MAGERR_APER', 'ZP_APER', 'ZPERR_APER',
'MAGSKY_APER_1', 'MAGERR_APER_1', 'ZP_APER_1', 'ZPERR_APER_1', 'MAGSKY_APER_2', 'MAGERR_APER_2', 'ZP_APER_2', 'ZPERR_APER_2',
'MAGSKY_APER_3', 'MAGERR_APER_3', 'ZP_APER_3', 'ZPERR_APER_3', 'MAGSKY_CIRC', 'MAGERR_CIRC'],
join_type='outer'
):
"""
Merge catalogs into single table.
Parameters
----------
max_distance_arcsec : float, optional
Maximum distance in arcseconds for matching sources.
ra_key : str, optional
Column name for right ascension.
dec_key : str, optional
Column name for declination.
data_keys : list[str], optional
List of column names for data.
join_type : str, optional
Type of join to use.
Returns
-------
merged_tbl : astropy.table.Table
Merged table of catalogs.
metadata : dict
Metadata of catalogs.
"""
from astropy.coordinates import SkyCoord
import numpy as np
import pandas as pd
import astropy.units as u
from tqdm import tqdm
catalogs = self.target_catalogs
dfs = []
coords = []
metadata = {}
data_keys_all = [ra_key, dec_key] + data_keys
# Step 1: Load and preprocess all catalogs
for i, catalog in tqdm(enumerate(catalogs), total=len(catalogs), desc="Preparing catalogs"):
tbl = catalog.target_data.copy()
ra = tbl[ra_key]
dec = tbl[dec_key]
mask = np.isfinite(ra) & np.isfinite(dec)
tbl = tbl[mask]
if len(tbl) == 0 or np.sum(np.isfinite(tbl[ra_key]) & np.isfinite(tbl[dec_key])) == 0:
# Still create a dummy DataFrame with NaNs
n_dummy = 1 # You can make this 1 or 0, depending on downstream needs
# row = {'ra': [0] * n_dummy, 'dec': [0] * n_dummy}
row = {
'ra_basis': [0] * n_dummy,
'dec_basis': [0] * n_dummy
}
for key in data_keys_all:
colname = f"{key}_idx{i}"
row[colname] = [np.nan] * n_dummy
df = pd.DataFrame(row)
df['catalog_id'] = i
df['match_id'] = -1
dfs.append(df)
coords.append(SkyCoord([0]*n_dummy * u.deg, [0]*n_dummy * u.deg)) # dummy coords
metadata[i] = catalog.info.to_dict()
continue
metadata[i] = catalog.info.to_dict()
# row = {'ra': tbl[ra_key], 'dec': tbl[dec_key]}
row = {
'ra_basis': tbl[ra_key],
'dec_basis': tbl[dec_key]
}
for key in data_keys_all:
colname = f"{key}_idx{i}"
row[colname] = tbl[key] if key in tbl.colnames else np.full(len(tbl), np.nan)
df = pd.DataFrame(row)
df['catalog_id'] = i
df['match_id'] = -1 # placeholder
dfs.append(df)
coords.append(SkyCoord(tbl[ra_key].value * u.deg, tbl[dec_key].value * u.deg))
if len(dfs) == 0:
return None, {}
# Step 2: Initialize merged_df with first catalog
merged_df = dfs[0].copy()
merged_df['match_id'] = np.arange(len(merged_df))
for i in tqdm(range(1, len(dfs)), desc="Merging catalogs"):
c1 = SkyCoord(merged_df['ra_basis'].values * u.deg, merged_df['dec_basis'].values * u.deg)
c2 = coords[i]
df2 = dfs[i].copy()
# Match c2 ? c1
idx, d2d, _ = c2.match_to_catalog_sky(c1)
sep_mask = d2d.arcsec < max_distance_arcsec
df2.loc[sep_mask, 'match_id'] = merged_df.iloc[idx[sep_mask]]['match_id'].values
matched = df2[df2['match_id'] >= 0].copy()
unmatched = df2[df2['match_id'] < 0].copy()
# Assign new match_id for unmatched
if len(unmatched) > 0:
unmatched['match_id'] = np.arange(
merged_df['match_id'].max() + 1,
merged_df['match_id'].max() + 1 + len(unmatched)
)
# Avoid duplicated column merge
matched = matched[[col for col in matched.columns if col not in merged_df.columns or col == 'match_id']]
merged_df = pd.merge(merged_df, matched, on='match_id', how=join_type)
if join_type == 'outer' and len(unmatched) > 0:
merged_df = pd.concat([merged_df, unmatched], ignore_index=True)
merged_df = merged_df.drop_duplicates('match_id', keep = 'first')
# Step 3: Add detection count
main_key = data_keys[0]
match_cols = [col for col in merged_df.columns if col.startswith(main_key)]
merged_df['n_detections'] = merged_df[match_cols].notna().sum(axis=1)
# Remove columns that are completely NaN (dummy columns)
idx_cols = [col for col in merged_df.columns if '_idx' in col]
is_dummy_row = merged_df[idx_cols].isna().all(axis=1)
merged_df = merged_df[~is_dummy_row].copy()
merged_tbl = Table.from_pandas(merged_df)
# Add coord column
coord = SkyCoord(ra=merged_tbl['ra_basis'] * u.deg, dec=merged_tbl['dec_basis'] * u.deg)
merged_tbl['coord'] = coord
merged_tbl.sort('n_detections', reverse=True)
return merged_tbl, metadata
def exclude_catalogs(self,
file_key=None,
filter=None,
exptime=None,
objname=None,
obs_start=None,
obs_end=None,
seeing=None,
depth=None,
observatory=None,
telname=None):
"""
Exclude catalogs that match the given criteria from self.catalogs.
Select catalogs from self.catalogs and update self.target_catalogs.
Parameters
----------
file_key : str, optional
File key to exclude.
filter : str, optional
Filter to exclude.
exptime : float, optional
Exposure time to exclude.
objname : str, optional
Object name to exclude.
obs_start : str, optional
Observation start time to exclude.
obs_end : str, optional
Observation end time to exclude.
seeing : float, optional
Seeing to exclude.
depth : float, optional
Depth to exclude.
observatory : str, optional
Observatory to exclude.
telname : str, optional
Telescope name to exclude.
Returns
-------
None
"""
df = self.df
# Convert inputs to arrays
if file_key is not None:
file_key = np.atleast_1d(file_key)
for key in file_key:
key = key.replace('*', '') if '*' in key else key
df = df[~df['path'].str.contains(key)]
if filter is not None:
filter = np.atleast_1d(filter)
df = df[~df['filter'].isin(filter)]
if exptime is not None:
exptime = np.atleast_1d(exptime)
df = df[~df['exptime'].isin(exptime)]
if objname is not None:
objname = np.atleast_1d(objname)
df = df[~df['objname'].isin(objname)]
if obs_start is not None:
obs_start = self.helper.flexible_time_parser(obs_start)
df = df[Time(df['obsdate'].tolist()) < obs_start]
if obs_end is not None:
obs_end = self.helper.flexible_time_parser(obs_end)
df = df[Time(df['obsdate'].tolist()) > obs_end]
if seeing is not None:
df = df[df['seeing'] >= seeing]
if depth is not None:
df = df[df['depth'] <= depth]
if observatory is not None:
observatory = np.atleast_1d(observatory)
df = df[~df['observatory'].isin(observatory)]
if telname is not None:
telname = np.atleast_1d(telname)
df = df[~df['telname'].isin(telname)]
# Update target_catalogs
if df.empty:
self.target_catalogs = []
else:
self.target_catalogs = [self.catalogs[i] for i in df.index]
self._last_filter = {
'file_key': file_key,
'filter': filter,
'exptime': exptime,
'objname': objname,
'obs_start': obs_start,
'obs_end': obs_end,
'seeing': seeing,
'depth': depth,
'observatory': observatory,
'telname': telname,
}
print(f"[INFO] Excluded catalogs based on given criteria. Remaining: {len(self.target_catalogs)}")
def select_catalogs(self,
file_key=None,
filter=None,
exptime=None,
objname=None,
obs_start=None,
obs_end=None,
seeing=None,
depth=None,
observatory=None,
telname=None):
"""
Select catalogs that match the given criteria from self.catalogs.
Select catalogs from self.catalogs and update self.target_catalogs.
Parameters
----------
file_key : str, optional
File key to select.
filter : str, optional
Filter to select.
exptime : float, optional
Exposure time to select.
objname : str, optional
Object name to select.
obs_start : str, optional
Observation start time to select.
obs_end : str, optional
Observation end time to select.
seeing : float, optional
Seeing to select.
depth : float, optional
Depth to select.
observatory : str, optional
Observatory to select.
telname : str, optional
Telescope name to select.
Returns
-------
None
"""
df = self.df
# Convert inputs to arrays
if file_key is not None:
file_key = np.atleast_1d(file_key)
for key in file_key:
key = key.replace('*', '') if '*' in key else key
df = df[df['path'].str.contains(key)]
if filter is not None:
filter = np.atleast_1d(filter)
df = df[df['filter'].isin(filter)]
if exptime is not None:
exptime = np.atleast_1d(exptime)
df = df[df['exptime'].isin(exptime)]
if objname is not None:
objname = np.atleast_1d(objname)
df = df[df['objname'].isin(objname)]
if obs_start is not None:
obs_start = self.helper.flexible_time_parser(obs_start)
df = df[Time(df['obsdate'].tolist()) >= obs_start]
if obs_end is not None:
obs_end = self.helper.flexible_time_parser(obs_end)
df = df[Time(df['obsdate'].tolist()) <= obs_end]
if seeing is not None:
df = df[df['seeing'] < seeing]
if depth is not None:
df = df[df['depth'] > depth]
if observatory is not None:
observatory = np.atleast_1d(observatory)
df = df[df['observatory'].isin(observatory)]
if telname is not None:
telname = np.atleast_1d(telname)
df = df[df['telname'].isin(telname)]
# Update target_catalogs
if df.empty:
self.target_catalogs = []
else:
self.target_catalogs = [self.catalogs[i] for i in df.index]
self._last_filter = {
'file_key': file_key,
'filter': filter,
'exptime': exptime,
'objname': objname,
'obs_start': obs_start,
'obs_end': obs_end,
'seeing': seeing,
'depth': depth,
'observatory': observatory,
'telname': telname,
}
self._last_mode = "select" # <-- mark as select
def divide_catalogs(
self,
by_filter: bool = False, # True when LC is required
by_exptime: bool = False,
by_objname: bool = True,
by_telname: bool = False,
by_observatory: bool = False,
by_obsdate: bool = True, # False when LC is required
obsdate_delta: float = 0.5,
obsdate_key: str = 'obsdate',
):
"""
Divide CatalogSet into multiple CatalogSet groups.
"""
group_keys = ['filter', 'exptime', 'objname', 'observatory', 'telname', 'group']
group_bools = [by_filter, by_exptime, by_objname, by_observatory, by_telname, by_obsdate]
group_keys_applied = [k for k, b in zip(group_keys, group_bools) if b]
df = self.target_df
df['mjd'] = Time(df['obsdate'].tolist()).mjd
groupped_tbl = self.helper.group_table(Table.from_pandas(df), key = 'mjd', tolerance = obsdate_delta)
df = groupped_tbl.to_pandas()
if df.empty:
return []
if by_obsdate:
tbl = Table.from_pandas(df)
tbl['mjd'] = Time(tbl[obsdate_key].tolist()).mjd
tbl = self.helper.group_table(tbl, key='mjd', tolerance=obsdate_delta)
else:
tbl = Table.from_pandas(df)
groups = tbl.group_by(group_keys_applied).groups
all_sets = []
for g in groups:
catalogs = [row['catalog'] for row in g]
all_sets.append(CatalogSet(catalogs))
return all_sets
def select_sources(self,
                   x,
                   y,
                   unit: str = 'coord',
                   matching_radius: float = 60,
                   x_key: str = 'X_WORLD',
                   y_key: str = 'Y_WORLD',
                   ):
    """
    Run ``Catalog.select_sources`` on every selected catalog.

    All arguments are forwarded unchanged to the ``select_sources`` method
    of each catalog in ``self.target_catalogs``.

    Parameters
    ----------
    x : float
        First coordinate of the search position.
    y : float
        Second coordinate of the search position.
    unit : str, optional
        How ``x`` and ``y`` are to be interpreted.
        NOTE(review): the default 'coord' presumably means sky coordinates
        (RA/Dec in degrees) - confirm against Catalog.select_sources.
    matching_radius : float, optional
        Search radius around the position.
        NOTE(review): presumably arcseconds - confirm against
        Catalog.select_sources.
    x_key : str, optional
        Catalog column compared against ``x``.
    y_key : str, optional
        Catalog column compared against ``y``.

    Returns
    -------
    None
    """
    for cat in tqdm(self.target_catalogs, desc='Selecting sources...'):
        cat.select_sources(x, y, unit=unit, matching_radius=matching_radius, x_key=x_key, y_key=y_key)
# @property
# def df(self):
# """
# Return a DataFrame containing metadata of all catalogs.
# """
# if len(self.catalogs) == 0:
# return pd.DataFrame()
# rows = []
# for cat in self.catalogs:
# info = cat.info
# rows.append({
# 'catalog': cat,
# 'path': info.path,
# 'filter': info.filter,
# 'exptime': info.exptime,
# 'obsdate': info.obsdate,
# 'observatory': info.observatory,
# 'telname': info.telname,
# 'objname': info.objname,
# 'seeing': info.seeing,
# 'depth': info.depth,
# 'ra': info.ra,
# 'dec': info.dec,
# 'fov_ra': info.fov_ra,
# 'fov_dec': info.fov_dec,
# })
# return pd.DataFrame(rows)
@property
def target_df(self):
"""DataFrame of selected catalogs."""
if len(self.target_catalogs) == 0:
return pd.DataFrame()
# indices of selected catalogs in master list
indices = [self.catalogs.index(cat) for cat in self.target_catalogs]
return self.df.loc[indices].copy()
@property
def df(self):
"""Pandas DataFrame of all catalogs (cached)."""
if self._df is not None:
return self._df
if len(self.catalogs) == 0:
self._df = pd.DataFrame()
return self._df
rows = []
for cat in self.catalogs:
info = cat.info
rows.append({
'catalog': cat,
'path': info.path,
'filter': info.filter,
'exptime': info.exptime,
'obsdate': info.obsdate,
'observatory': info.observatory,
'telname': info.telname,
'objname': info.objname,
'seeing': info.seeing,
'depth': info.depth,
'ra': info.ra,
'dec': info.dec,
'fov_ra': info.fov_ra,
'fov_dec': info.fov_dec,
})
self._df = pd.DataFrame(rows)
return self._df
def _load_catalog_worker(self, args):
catalog_file, existing_paths = args
try:
if str(catalog_file) in existing_paths:
return 'skipped', str(catalog_file), None
catalog = Catalog(path=catalog_file, catalog_type='all', load=True)
if not catalog.is_loaded:
load_result = catalog.load_target_img(target_img=None)
if catalog.is_loaded:
return 'success', str(catalog_file), catalog
else:
return 'failed', str(catalog_file), None
except Exception as e:
return 'failed', str(catalog_file), None