# Source code for vista.algorithms.tracks.extraction

"""
Track extraction algorithm for extracting image chips and detecting signal pixels.

This module implements track extraction that crops image chips around each track point,
detects signal pixels using CFAR-like thresholding, computes local noise statistics,
and optionally refines track coordinates using weighted centroids.
"""
import numpy as np
from numpy.typing import NDArray
from skimage.measure import label, regionprops
from vista.algorithms.detectors.cfar import CFAR
from vista.imagery.imagery import Imagery
from vista.tracks.track import Track


class TrackExtraction:
    """
    Extract image chips and detect signal pixels around track points.

    For each track point, this algorithm:

    1. Extracts a square image chip of specified diameter
    2. Detects signal pixels using CFAR-like thresholding
    3. Computes local noise standard deviation from background annulus
    4. Optionally updates track coordinates to weighted centroid of signal blob

    Parameters
    ----------
    track : Track
        Track object containing trajectory points
    imagery : Imagery
        Imagery object to extract chips from
    chip_radius : int
        Radius of square chips to extract (in pixels). Total chip diameter
        will be 2*radius + 1. Python and NumPy integers are accepted.
    background_radius : int
        Outer radius for background noise calculation (pixels)
    ignore_radius : int
        Inner radius to exclude from background (guard region, pixels)
    threshold_deviation : float
        Number of standard deviations above mean for signal detection
    annulus_shape : str, optional
        Shape of the annulus ('circular' or 'square'), by default 'circular'
    search_radius : int, optional
        When specified, only keep signal blobs that have at least one pixel
        within the central search region of this radius. By default None
        (keep all blobs)
    update_centroids : bool, optional
        If True, update track coordinates to signal blob centroids,
        by default False
    max_centroid_shift : float, optional
        Maximum allowed centroid shift in pixels. Points with larger shifts
        are not updated. By default np.inf (no limit)

    Attributes
    ----------
    name : str
        Algorithm name ("Track Extraction")
    chip_diameter : int
        Computed chip diameter (2 * chip_radius + 1)
    cfar_detector : CFAR
        Detector built from the background/threshold parameters and applied
        to every chip

    Methods
    -------
    __call__()
        Process all track points and return extraction results: a dict with
        'chips', 'signal_masks', 'noise_stds', 'updated_rows' and
        'updated_columns' (see ``__call__`` for shapes)

    Notes
    -----
    - Chips are centered on the track point rounded to the nearest pixel;
      parts of a chip that fall outside the image are padded with np.nan
    - Signal detection threshold: pixel > mean + threshold_deviation * std
    - Only the largest connected signal blob is used for centroid calculation
    - Points whose frame is not in the imagery, or whose coordinates are not
      finite, get a NaN chip, an all-False mask, NaN noise and unchanged
      coordinates
    - Centroid updates respect max_centroid_shift, measured as the distance
      between the original track point and the refined centroid
    """

    name = "Track Extraction"

    def __init__(self, track: "Track", imagery: "Imagery", chip_radius: int,
                 background_radius: int, ignore_radius: int,
                 threshold_deviation: float, annulus_shape: str = 'circular',
                 search_radius: "int | None" = None,
                 update_centroids: bool = False,
                 max_centroid_shift: float = np.inf):
        """
        Store the configuration and build the CFAR detector used on each chip.

        Raises
        ------
        ValueError
            If chip_radius is not a positive integer.
        """
        # Accept Python and NumPy integers, but reject bool (a subclass of int)
        # so that e.g. ``True`` is not silently treated as a radius of 1.
        if (not isinstance(chip_radius, (int, np.integer))
                or isinstance(chip_radius, bool)
                or chip_radius <= 0):
            raise ValueError(f"chip_radius must be a positive integer, got {chip_radius}")

        self.track = track
        self.imagery = imagery
        self.chip_radius = int(chip_radius)
        self.chip_diameter = 2 * self.chip_radius + 1
        self.background_radius = background_radius
        self.ignore_radius = ignore_radius
        self.threshold_deviation = threshold_deviation
        self.annulus_shape = annulus_shape
        self.search_radius = search_radius
        self.update_centroids = update_centroids
        self.max_centroid_shift = max_centroid_shift

        # Create CFAR detector instance for chip processing
        self.cfar_detector = CFAR(
            background_radius=background_radius,
            ignore_radius=ignore_radius,
            threshold_deviation=threshold_deviation,
            annulus_shape=annulus_shape,
            search_radius=search_radius,
        )

    @staticmethod
    def _chip_center(row, col):
        """Return the integer image pixel (row, col) a chip is centered on.

        Uses ``np.round`` (halves round to even). Both chip extraction and
        centroid conversion go through this helper so they always agree.
        """
        return int(np.round(row)), int(np.round(col))

    @classmethod
    def _chip_offset_to_image(cls, row, col, offset_row, offset_col):
        """Convert a centroid offset from the chip center into image coordinates.

        The chip is centered on the *rounded* track position, so the offset is
        applied to that integer pixel rather than to the fractional (row, col).
        Otherwise the result would depend on where inside the pixel the
        starting estimate happened to lie.
        """
        center_row, center_col = cls._chip_center(row, col)
        return center_row + offset_row, center_col + offset_col

    def _extract_chip(self, image: NDArray, row: float, col: float) -> NDArray:
        """
        Extract a square chip from the image centered at (row, col).

        Handles edge cases by padding with np.nan where the chip extends beyond
        the image bounds. The chip is stored as float32.

        Parameters
        ----------
        image : NDArray
            2D image array
        row : float
            Row coordinate of chip center (rounded to the nearest pixel)
        col : float
            Column coordinate of chip center (rounded to the nearest pixel)

        Returns
        -------
        NDArray
            Extracted chip of shape (chip_diameter, chip_diameter)
        """
        radius = self.chip_diameter // 2
        chip = np.full((self.chip_diameter, self.chip_diameter), np.nan, dtype=np.float32)

        # Chip bounds in image coordinates (end is exclusive)
        row_center, col_center = self._chip_center(row, col)
        chip_row_start = row_center - radius
        chip_row_end = row_center + radius + 1
        chip_col_start = col_center - radius
        chip_col_end = col_center + radius + 1

        # Part of the chip that overlaps the image
        img_rows, img_cols = image.shape
        valid_row_start = max(0, chip_row_start)
        valid_row_end = min(img_rows, chip_row_end)
        valid_col_start = max(0, chip_col_start)
        valid_col_end = min(img_cols, chip_col_end)

        # Same region expressed in chip coordinates
        chip_valid_row_start = valid_row_start - chip_row_start
        chip_valid_row_end = chip_valid_row_start + (valid_row_end - valid_row_start)
        chip_valid_col_start = valid_col_start - chip_col_start
        chip_valid_col_end = chip_valid_col_start + (valid_col_end - valid_col_start)

        # Copy only when there is a non-empty overlap; otherwise the chip stays NaN
        if valid_row_end > valid_row_start and valid_col_end > valid_col_start:
            chip[chip_valid_row_start:chip_valid_row_end,
                 chip_valid_col_start:chip_valid_col_end] = \
                image[valid_row_start:valid_row_end, valid_col_start:valid_col_end]

        return chip

    def _compute_weighted_centroid(self, chip: NDArray, signal_mask: NDArray) -> tuple:
        """
        Compute the intensity-weighted centroid of the largest signal blob.

        Parameters
        ----------
        chip : NDArray
            Image chip (used as the weights)
        signal_mask : NDArray
            Boolean mask of signal pixels

        Returns
        -------
        tuple
            (centroid_row, centroid_col) as an offset from the chip's center
            pixel index, with 0.5 added so a pixel's center lies at index + 0.5.
            Returns (0.0, 0.0) if there is no signal.
        """
        if not np.any(signal_mask):
            return 0.0, 0.0

        # Label connected components and keep the largest blob by pixel area
        labeled = label(signal_mask)
        if labeled.max() == 0:
            return 0.0, 0.0

        regions = regionprops(labeled, intensity_image=chip)
        largest_region = max(regions, key=lambda r: r.area)

        # weighted_centroid is in chip array-index coordinates
        centroid = largest_region.weighted_centroid

        # Convert to offset from chip center (accounting for pixel center at 0.5, 0.5)
        chip_center = self.chip_diameter // 2
        centroid_offset_row = centroid[0] + 0.5 - chip_center
        centroid_offset_col = centroid[1] + 0.5 - chip_center

        return centroid_offset_row, centroid_offset_col

    def __call__(self):
        """
        Process all track points and extract chips with signal detection.

        Returns
        -------
        dict
            Dictionary containing:

            - 'chips': Image chips array (n_points, diameter, diameter)
            - 'signal_masks': Signal pixel masks (n_points, diameter, diameter)
            - 'noise_stds': Noise standard deviations (n_points,)
            - 'updated_rows': Updated row coordinates (n_points,)
            - 'updated_columns': Updated column coordinates (n_points,)
        """
        n_points = len(self.track)
        diameter = self.chip_diameter

        # Initialize output arrays (masks default to all-False)
        chips = np.zeros((n_points, diameter, diameter), dtype=np.float32)
        signal_masks = np.zeros((n_points, diameter, diameter), dtype=bool)
        noise_stds = np.zeros(n_points, dtype=np.float32)
        updated_rows = self.track.rows.copy()
        updated_columns = self.track.columns.copy()

        # Map frame number -> position in the imagery stack
        imagery_frame_index = {frame: idx for idx, frame in enumerate(self.imagery.frames)}
        # CFAR search region is centered on the chip's middle pixel
        chip_center = diameter // 2

        for i in range(n_points):
            frame = self.track.frames[i]
            row = self.track.rows[i]
            col = self.track.columns[i]

            image_idx = imagery_frame_index.get(frame)
            # Frame not in imagery, or coordinates that cannot be rounded to a
            # pixel: emit NaN chip/noise, empty mask, and leave the point as-is.
            if image_idx is None or not (np.isfinite(row) and np.isfinite(col)):
                chips[i] = np.nan
                signal_masks[i] = False
                noise_stds[i] = np.nan
                continue

            image = self.imagery.images[image_idx]
            chip = self._extract_chip(image, row, col)
            chips[i] = chip

            # Use CFAR to detect signal pixels and compute noise std
            signal_mask, noise_std = self.cfar_detector.process_chip(
                chip, search_center=(chip_center, chip_center)
            )
            signal_masks[i] = signal_mask
            noise_stds[i] = noise_std

            # With no signal pixels there is nothing to refine
            if not self.update_centroids or not np.any(signal_mask):
                continue

            offset_row, offset_col = self._compute_weighted_centroid(chip, signal_mask)
            new_row, new_col = self._chip_offset_to_image(row, col, offset_row, offset_col)

            # Limit the actual displacement from the original track point.
            # A NaN centroid (e.g. zero total weight) fails this check.
            shift_distance = np.hypot(new_row - row, new_col - col)
            if shift_distance <= self.max_centroid_shift:
                updated_rows[i] = new_row
                updated_columns[i] = new_col

        return {
            'chips': chips,
            'signal_masks': signal_masks,
            'noise_stds': noise_stds,
            'updated_rows': updated_rows,
            'updated_columns': updated_columns,
        }