Source code for ezphot.dataobjects.catalog

#%%
import os
import json
import logging
import inspect
from pathlib import Path
from typing import Union, Optional
import fnmatch

import numpy as np
from numba import jit
from astropy.io import fits
from astropy.io.fits import Header
from astropy.time import Time
from astropy.table import Table
from astropy.coordinates import SkyCoord
from scipy.spatial import cKDTree
from astropy import units as u
from astropy.wcs.utils import skycoord_to_pixel
import matplotlib.pyplot as plt

from types import SimpleNamespace
from ezphot.imageobjects import ScienceImage, ReferenceImage, Mask
from ezphot.helper import Helper


class Info:
    """Stores metadata of a FITS image with dot-access."""
    
    INFO_FIELDS = ["path", "target_img", "obsdate", "filter", "exptime", "depth", "seeing",
                   "catalog_type", "ra", "dec", "fov_ra", "fov_dec", 'objname',
                   "observatory", "telname"]
    DEFAULT_VALUES = [None] * len(INFO_FIELDS)

    def __init__(self, **kwargs):
        # Set defaults, then override with user-provided values
        self._fields = {
            field: kwargs.get(field, default)
            for field, default in zip(self.INFO_FIELDS, self.DEFAULT_VALUES)
        }

    def __getattr__(self, name):
        # Prevent infinite recursion when _fields is not yet initialized
        if '_fields' in self.__dict__ and name in self._fields:
            return self._fields[name]
        raise AttributeError(f"'Info' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        if name == "_fields":
            super().__setattr__(name, value)
        elif "_fields" in self.__dict__ and name in self._fields:
            self._fields[name] = value
        else:
            raise AttributeError(f"'Info' object has no attribute '{name}'")

    def update(self, key, value):
        if key in self._fields:
            self._fields[key] = value
        else:
            print(f"WARNING: Invalid key: {key}")
        
    def copy(self):
        return Info.from_dict(self.to_dict())

    def to_dict(self):
        return dict(self._fields)

    @classmethod
    def from_dict(cls, data):
        return cls(**{key: data.get(key) for key in cls.INFO_FIELDS})

    def __repr__(self):
        lines = [f"{key}: {value}" for key, value in self._fields.items()]
        return "Info ============================================\n  " + "\n  ".join(lines) + "\n==================================================="


[docs] class Catalog: """ Catalog class for handling source catalog data. This class is designed to handle source catalog data from a FITS file. It automatically loads the corresponding target image from the FITS file. """
[docs] def __init__(self, path: Union[Path, str], catalog_type: str = 'all', info: Info = None, load: bool = True): """ Initialize the Catalog instance. Parameters ---------- path : Union[Path, str] Path to the catalog file. catalog_type : str, optional Catalog type. Default is 'all'. ['all', 'reference', 'valid', 'transient', 'forced'] info : Info, optional Info object. Default is None. load : bool, optional Whether to load the catalog data. Default is True. """ path = Path(path) if catalog_type not in ['all', 'reference', 'valid', 'transient', 'candidate', 'forced']: raise ValueError(f"Invalid catalog type: {catalog_type}") self.helper = Helper() self.is_loaded = False self.path = path self.catalog_type = catalog_type self.target_img = None self._data = None self._target_data = None self.info = Info(path = str(path), catalog_type = catalog_type) if load: self.load_info() if info is not None: self.info = info self.info.path = str(self.path) self.info.catalog_type = catalog_type if self.info.target_img is not None: if Path(self.info.target_img).exists(): self.target_img = ScienceImage(self.info.target_img, load = True) if self.target_img is None: self._load_target_img()
def __repr__(self): return f"Catalog( N_selected/N_sources = {self.nselected}/{self.nsources}, is_exists={self.is_exists}, catalog_type={self.catalog_type}, path={self.path})" def help(self): # Get all public methods from the class, excluding `help` methods = [ (name, obj) for name, obj in inspect.getmembers(self.__class__, inspect.isfunction) if not name.startswith("_") and name != "help" ] # Build plain text list with parameters lines = [] for name, func in methods: sig = inspect.signature(func) params = [str(p) for p in sig.parameters.values() if p.name != "self"] sig_str = f"({', '.join(params)})" if params else "()" lines.append(f"- {name}{sig_str}") # Final plain text output help_text = "" print(f"Help for {self.__class__.__name__}\n{help_text}\n\nPublic methods:\n" + "\n".join(lines)) def select_sources(self, x, y, unit='coord', matching_radius=5.0, x_key='X_WORLD', y_key='Y_WORLD'): """ Select all sources within matching_radius. """ x = np.atleast_1d(x) y = np.atleast_1d(y) target_catalog = self.data if len(target_catalog) == 0: self.target_data = target_catalog return if unit == 'pixel': catalog_coords = np.vstack((target_catalog[x_key], target_catalog[y_key])).T input_coords = np.vstack((x, y)).T tree = cKDTree(catalog_coords) # List of lists of indices idx_lists = tree.query_ball_point(input_coords, r=matching_radius) # Flatten + unique idx = np.unique(np.concatenate(idx_lists)) if len(idx_lists) else [] self.target_data = target_catalog[idx] elif unit == 'coord': cat_sky = SkyCoord( ra=target_catalog[x_key], dec=target_catalog[y_key], unit='deg' ) input_sky = SkyCoord(ra=x, dec=y, unit='deg') # Cartesian coordinates for KDTree cat_xyz = np.vstack(cat_sky.cartesian.xyz).T input_xyz = np.vstack(input_sky.cartesian.xyz).T tree = cKDTree(cat_xyz) matching_radius_rad = np.deg2rad(matching_radius / 3600.0) idx_lists = tree.query_ball_point(input_xyz, r=matching_radius_rad) idx = np.unique(np.concatenate(idx_lists)) if len(idx_lists) else [] self.target_data = target_catalog[idx] else: raise ValueError("unit must be either 'pixel' or 'coord'") def show_source(self, ra: float, dec: float, radius_arcsec: float = 5.0, downsample: int = 4, zoom_radius_pixel: float = 50, matching_radius_arcsec: float = 3.0, ra_key: str = 'X_WORLD', dec_key: str = 'Y_WORLD', ): """ Show two-panel view of the target image with a single source marked (red). The left panel shows the full image, and the right panel shows a zoomed-in view of the target position. If a source is matched, it is marked with a blue circle. Parameters ---------- ra : float Right ascension of the target source. dec : float Declination of the target source. downsample : int, optional Downsampling factor for the image. Default is 4. zoom_radius_pixel : float, optional Radius of the zoomed-in view in pixels. Default is 50. matching_radius_arcsec : float, optional Matching radius in arcseconds. Default is 3.0. Returns ------- fig : matplotlib.figure.Figure Figure object. """ from astropy.coordinates import SkyCoord from astropy import units as u import matplotlib.pyplot as plt import numpy as np # Load image if not yet loaded if self.target_img is None: load_result = self._load_target_img(target_img=None) # Convert RA/Dec to pixel coordinates coord = SkyCoord(ra=ra * u.deg, dec=dec * u.deg) x, y = self.target_img.wcs.world_to_pixel(coord) # Match source in catalog self.select_sources( x=[ra], y=[dec], unit='coord', matching_radius=matching_radius_arcsec) matched_catalog = self.target_data # If matched, get pixel coords of catalog source matched_xy = None if len(matched_catalog) > 0: matched_coord = SkyCoord(ra=matched_catalog[0][ra_key] * u.deg, dec=matched_catalog[0][dec_key] * u.deg) matched_xy = self.target_img.wcs.world_to_pixel(matched_coord) # Downsampled dimensions for full image view image_shape = self.target_img.data.shape x_size_ds = image_shape[1] / downsample y_size_ds = image_shape[0] / downsample x_ds = x / downsample y_ds = y / downsample # Create figure with two subplots fig, (ax_full, ax_zoom) = plt.subplots(1, 2, figsize=(14, 6)) # --- Full image --- fig_full, _ = self.target_img.show(downsample=downsample, title=None) plt.close(fig_full) full_image = fig_full.axes[0].images[0] ax_full.imshow(full_image.get_array(), cmap=full_image.get_cmap(), origin='lower', vmin=full_image.get_clim()[0], vmax=full_image.get_clim()[1]) ax_full.set_title(f"Full Image of {self.target_img.objname}") ax_full.plot(x_ds, y_ds, 'ro', markersize=6, label='Requested Position') if matched_xy is not None: matched_x_ds, matched_y_ds = matched_xy[0] / downsample, matched_xy[1] / downsample ax_full.plot(matched_x_ds, matched_y_ds, 'bo', markersize=6, label='Matched Source') ax_full.legend() # --- Zoomed view --- fig_zoom, _ = self.target_img.show(downsample=1, title=None) plt.close(fig_zoom) zoom_image = fig_zoom.axes[0].images[0] ax_zoom.imshow(zoom_image.get_array(), cmap=zoom_image.get_cmap(), origin='lower', vmin=zoom_image.get_clim()[0], vmax=zoom_image.get_clim()[1]) ax_zoom.set_title("Zoom on Target") # Draw red circle for requested position pixel_scale = np.abs(self.target_img.wcs.pixel_scale_matrix[0, 0]) * 3600 # arcsec/pixel radius_pixel = radius_arcsec / pixel_scale circ = plt.Circle((x, y), radius_pixel, color='red', fill=False, linestyle='--', linewidth=2.5, alpha=0.7) ax_zoom.add_patch(circ) ax_zoom.text(x, y + 1.5 * radius_pixel, 'Requested', color='red', fontsize=13, ha='center', va='center') # If matched, draw blue circle if matched_xy is not None: matched_x, matched_y = matched_xy circ_match = plt.Circle((matched_x, matched_y), radius_pixel, color='blue', fill=False, linestyle='-', linewidth=2.0, alpha=0.8) ax_zoom.add_patch(circ_match) ax_zoom.text(matched_x, matched_y - 1.5 * radius_pixel, 'Matched', color='blue', fontsize=13, ha='center', va='center') # Zoom limits ax_zoom.set_xlim(x - zoom_radius_pixel, x + zoom_radius_pixel) ax_zoom.set_ylim(y - zoom_radius_pixel, y + zoom_radius_pixel) fig.tight_layout() plt.show() return fig def copy(self) -> "Catalog": """ Return a deep copy of this Catalog instance. Returns ------- copied_catalog : Catalog A deep copy of the Catalog instance. Examples -------- >>> catalog = Catalog(path='catalog.fits') >>> copied_catalog = catalog.copy() >>> copied_catalog.path 'catalog.fits' """ new_instance = Catalog( path=self.path, catalog_type=self.catalog_type, info=Info.from_dict(self.info.to_dict()), load=False ) # Manually copy loaded data and header new_instance.data = None if self.data is None else self.data.copy() new_instance.target_img = None if self.target_img is None else ScienceImage(self.target_img.path, telinfo=self.target_img.telinfo, load=True) return new_instance def write(self, format = 'ascii', verbose: bool = True): """ Write catalog data to the savepath (self.savepath.savepath). Parameters ---------- format : str, optional Format of the output file. Default is 'ascii'. Returns ------- None Examples -------- """ if self.data is None: raise ValueError("Cannot save MaskImage: data is not registered.") os.makedirs(self.savepath.savedir, exist_ok=True) # Write to disk self.data.write(self.savepath.savepath, format=format, overwrite=True) self.helper.print(f'Saved: {self.savepath.savepath}', verbose) self.save_info() def remove(self, remove_main: bool = True, remove_connected_files: bool = True, skip_patterns: list = [['*.png', '*.cat']], verbose: bool = True) -> dict: removed = {} def try_remove(p: Union[str, Path]): p = Path(p) if p.exists() and p.is_file(): try: p.unlink() if verbose: print(f"[REMOVE] {p}") return True except Exception as e: if verbose: print(f"[FAILED] {p} - {e}") return False return False def should_skip(p: Path): for pattern in skip_patterns: if fnmatch.fnmatch(p.name, pattern): return True return False # Remove main FITS file if remove_main and self.path and self.path.is_file(): removed[str(self.path)] = try_remove(self.path) # Remove connected files if remove_connected_files: for f in self.connected_files: if should_skip(f): if verbose: print(f"[SKIP] {f} (matched skip pattern)") continue removed[str(f)] = try_remove(f) return removed def clear(self, clear_data: bool = True, verbose: bool = True): """Clear the image data and/or header from memory. Parameters ---------- clear_data : bool, optional If True, clear the image data. Default is True. """ if clear_data: self._data = None self._target_data = None self.helper.print("Cleared data from memory.", verbose) def apply_mask(self, target_ivpmask: Mask, x_key: str = 'X_IMAGE', y_key: str = 'Y_IMAGE'): """ Apply a mask to the catalog. Parameters ---------- target_ivpmask : Mask Mask to apply to the catalog. x_key : str, optional Column name for X coordinates. Default is 'X_IMAGE'. y_key : str, optional Column name for Y coordinates. Default is 'Y_IMAGE'. Returns ------- masked_sources : Table Catalog with sources that are not masked. """ mask = target_ivpmask.data ny, nx = mask.shape # Round or convert positions to int for indexing x = np.round(self.data[x_key]).astype(int) y = np.round(self.data[y_key]).astype(int) # Ensure coordinates are within image bounds valid = (x >= 0) & (x < nx) & (y >= 0) & (y < ny) x_valid = x[valid] y_valid = y[valid] # Check if pixel is masked (== 0) mask_values = mask[y_valid, x_valid] is_masked = (mask_values == 0) # Apply final selection final_indices = np.where(valid)[0][is_masked] masked_sources = self.data[final_indices] return masked_sources def apply_zp(self, target_img: Union[ScienceImage, ReferenceImage], save: bool = True, verbose: bool = True): """ Apply photometric zeropoint corrections using values saved in the FITS header. Parameters ---------- target_img : ScienceImage or ReferenceImage The target image to apply photometric zeropoint corrections to. save : bool Whether to save the catalog. verbose : bool Whether to print verbose output. """ from ezphot.methods import PhotometricCalibration photometriccalibration = PhotometricCalibration() result = photometriccalibration.apply_zp( target_img = target_img, target_catalog = self, save = save, verbose = verbose) return result def to_stamp(self, target_img: Union[ScienceImage, ReferenceImage], sort_by: str = 'FLUX_AUTO', max_number: int = 50000, x_key: str = 'X_WORLD', y_key: str = 'Y_WORLD', ): """ Convert X_WORLD and Y_WORLD to pixel coordinates and save to a stamp catalog. Parameters ---------- target_img : ScienceImage or ReferenceImage Target image to convert to pixel coordinates. sort_by : str, optional Column name to sort the catalog by. Default is 'FLUX_AUTO'. max_number : int, optional Maximum number of sources to save. Default is 50000. Returns ------- stamppath : str Path to the stamp catalog. """ # Convert X_WORLD and Y_WORLD to pixel coordinates and save to a stamp catalog. stamppath = self.savepath.stamppath wcs = target_img.wcs if sort_by in self.data.colnames: self.data.sort(sort_by) ra_deg = self.data[x_key] dec_deg = self.data[y_key] skycoord = SkyCoord(ra=ra_deg, dec=dec_deg, unit = 'deg') x_pix, y_pix = skycoord_to_pixel(skycoord, wcs, origin=0) if len(x_pix) > max_number: x_pix = x_pix[:max_number] y_pix = y_pix[:max_number] # Save stamp catalog to a file. with open(stamppath, "w") as f: for x, y in zip(x_pix, y_pix): f.write(f"{round(x,3)} {round(y,3)} \n") return stamppath def to_region(self, reg_size: float = 6.0, shape : str = 'circle'): """ Convert X_IMAGE and Y_IMAGE to a region file. Parameters ---------- reg_size : float, optional Size of the region in pixels. Default is 6.0. shape : str, optional Shape of the region. Default is 'circle'. """ reg_x = self.data['X_IMAGE'] reg_y = self.data['Y_IMAGE'] reg_a = None reg_b = None reg_theta = None if shape != 'circle': if 'A_IMAGE' not in self.data.colnames or 'B_IMAGE' not in self.data.colnames or 'THETA_IMAGE' not in self.data.colnames: raise ValueError("For non-circle shapes, A_IMAGE, B_IMAGE, and THETA_IMAGE must be present in the catalog data.") reg_a = self.data['A_IMAGE'] reg_b = self.data['B_IMAGE'] reg_theta = self.data['THETA_IMAGE'] region_path = str(self.savepath.savepath) + '.reg' self.helper.to_regions(reg_x = reg_x, reg_y = reg_y, reg_a = reg_a, reg_b = reg_b, reg_theta = reg_theta, reg_size = reg_size, output_file_path = region_path) return region_path def save_info(self, verbose = False): """ Save processing info to a JSON file. Parameters ---------- verbose : bool, optional If True, print the path of the info file. Default is False. Returns ------- None """ with open(self.savepath.infopath, 'w') as f: json.dump(self.info.to_dict(), f, indent=4) self.helper.print(f"Saved: {self.savepath.infopath}", verbose) def load_info(self, verbose = False): """ Load processing info from a JSON file. Parameters ---------- verbose : bool, optional If True, print the path of the info file. Default is False. Returns ------- info : Info Info object loaded from the JSON file. """ if not self.savepath.infopath.exists(): self.helper.print(f"Info file does not exist: {self.savepath.infopath}", verbose) return self.info with open(self.savepath.infopath, 'r') as f: data = json.load(f) self.info = Info.from_dict(data) if self.info.target_img is not None: target_path = Path(self.info.target_img) if not target_path.exists(): self.helper.print(f"Target image does not exist: {target_path}", verbose) return self.info self.target_img = ScienceImage(target_path, load=True) self.is_loaded = True self.helper.print(f"Loaded: {self.savepath.infopath}", verbose) return self.info # def _find_corresponding_fits(self) -> Optional[Path]: # search_dirs = [self.path.parent, self.path.parent.parent] # # Iteravely strip suffixes from the path to find candidates # candidates = [] # path = self.path # while path.suffix: # path = path.with_suffix('') # if path.suffix.startswith('.fits'): # candidates.append(path.name) # else: # candidate = Path(str(path) + '.fits') # if candidate.name not in candidates: # candidates.append(candidate.name) # # Search for candidate names in possible directories # for directory in search_dirs: # for name in candidates: # candidate_path = directory / name # if candidate_path.exists(): # return candidate_path # print(f"[WARNING] No matching .fits found for: {self.path}") # return None def _find_corresponding_fits(self) -> Optional[Path]: search_dirs = [self.path.parent, self.path.parent.parent] candidates = [] filename = self.path.name def add_candidate(name: str): if name not in candidates: candidates.append(name) # 1. explicit pattern handling # e.g. T01222_m575_7DT14_20260305_024900_300s_coadd_cat.fits # -> T01222_m575_7DT14_20260305_024900_300s_coadd.fits if filename.endswith('_cat.fits'): add_candidate(filename.replace('_cat.fits', '.fits')) # ν•„μš”ν•˜λ©΄ 이런 νŒ¨ν„΄λ„ μΆ”κ°€ κ°€λŠ₯ # if filename.endswith('_phot.cat.fits'): # add_candidate(filename.replace('_phot.cat.fits', '.fits')) # 2. iterative suffix stripping (existing logic, slightly cleaned) path = self.path while path.suffix: path = path.with_suffix('') if path.suffix.startswith('.fits'): add_candidate(path.name) else: add_candidate(f"{path.name}.fits") # 3. search for directory in search_dirs: for name in candidates: candidate_path = directory / name if candidate_path.exists(): return candidate_path print(f"[WARNING] No matching .fits found for: {self.path}") return None def _load_target_img(self, target_img: Union[ScienceImage, ReferenceImage] = None): if target_img is None: target_path = self._find_corresponding_fits() if target_path is None: print(f"[ERROR] No corresponding FITS file found for {self.path}") return False target_img = ScienceImage(target_path) self.target_img = target_img self.info.target_img = str(target_img.path) self.info.ra = target_img.ra self.info.dec = target_img.dec self.info.fov_ra = target_img.fovx self.info.fov_dec = target_img.fovy self.info.objname = target_img.objname self.info.obsdate = target_img.obsdate self.info.filter = target_img.filter self.info.exptime = target_img.exptime self.info.depth = target_img.depth self.info.seeing = target_img.seeing self.info.observatory = target_img.observatory self.info.telname = target_img.telname self.is_loaded = True return True @property def savedir(self) -> Union[Path, None]: """ Return the directory where this image and associated files will be saved. If a custom savedir was set, use it. Otherwise, build from config and metadata. Returns None if required fields are not available. """ # Use manually set savedir if provided if hasattr(self, '_savedir') and self._savedir is not None: return self._savedir # Default construction from config base_dir = self.path.parent return base_dir @savedir.setter def savedir(self, value: Union[str, Path]): """ Set a custom directory for saving the image and associated products. """ if value is None: self._savedir = None return value = Path(value) if value.is_file(): value = value.parent self._savedir = value @property def savepath(self): """Dynamically builds save paths based on the path""" savedir = self.savedir filename = self.path.name return SimpleNamespace( savedir = savedir, savepath = savedir / filename, refcatalogpath = (savedir / filename).with_suffix('.refcat'), transientcatalogpath = (savedir / filename).with_suffix('.transient'), candidatecatalogpath = (savedir / filename).with_suffix('.candidate'), stamppath = savedir / (filename + '.stamp'), infopath= savedir / (filename + '.info')) @property def connected_files(self) -> set: """ Return all associated files that would be deleted in `remove()` if remove_connected_files=True, excluding the main FITS file (`self.path`). Only includes existing files, not directories. Returns ------- connected_files : set All connected auxiliary files. """ connected = set() # Files in same directory that start with the same base name (excluding self.path) base_dir = self.path.parent base_name = self.path.name for f in base_dir.iterdir(): if f.is_file() and f.name.startswith(base_name) and f != self.path: connected.add(f) # Files explicitly listed in savepath (excluding self.path) for p in vars(self.savepath).values(): if isinstance(p, Path) and p.exists() and p.is_file() and p != self.path: connected.add(p) return connected @property def data(self): """Lazy-load table data from path by trying multiple formats.""" if not self.is_data_loaded and self.is_exists: tried_formats = [ 'fits', 'ascii.sextractor', 'ascii', 'csv', 'ascii.basic', 'ascii.commented_header', 'ascii.tab', 'ascii.fast_no_header', ] for fmt in tried_formats: try: self._data = Table.read(self.path, format=fmt) self._target_data = self._data.copy() # Keep a copy of the original data return self._data # Success except Exception: continue # Try next format self._data = None return self._data @data.setter def data(self, value): self._data = value @property def is_data_loaded(self): """Check if the data is loaded.""" return self._data is not None @property def target_data(self): """Return the selected sources by self.select_sources().""" if self._target_data is None: return self.data return self._target_data @target_data.setter def target_data(self, value): """Set the target data and update the info.""" self._target_data = value @property def is_exists(self): """Check if the catalog file exists.""" return self.path.exists() @property def is_saved(self): """Check if the catalog has been saved.""" if self.savepath.savepath is None: return False return self.savepath.savepath.exists() @property def nsources(self): """Number of sources in the catalog.""" if self.is_data_loaded: return len(self.data) return None @property def nselected(self): """Number of selected sources in the target data.""" if self._target_data is not None: return len(self._target_data) return None