Source code for vista.algorithms.background_removal.subspace_background_removal

"""Sliding subspace background removal for VISTA

This module implements a sliding-window subspace (low-rank SVD) background estimator
using NumPy. For each frame, a background estimate is formed from surrounding frames
(with a temporal gap to exclude target signal), projected onto a low-rank subspace
via truncated SVD.

The temporal gap ensures that slowly-moving unresolved targets do not contaminate
the background estimate for the frame being processed.

When tiling is enabled, each frame is divided into non-overlapping square tiles that
are processed independently, reducing the per-SVD matrix size.
"""
import numpy as np

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False


def find_knee(singular_values):
    """
    Locate the knee (elbow) of a descending singular-value curve with PyTorch.

    A straight chord is drawn between the first and last points of the curve;
    the knee is the sample that lies farthest (perpendicular distance) from
    that chord.

    Parameters
    ----------
    singular_values : torch.Tensor
        1D tensor of singular values in descending order

    Returns
    -------
    int
        Rank at the knee point (1-indexed, minimum 1). Curves with two or
        fewer samples always yield 1.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch is required for find_knee(). Use _find_knee_numpy() instead.")

    count = len(singular_values)
    if count <= 2:
        # Too few samples to define a bend in the curve.
        return 1

    positions = torch.arange(
        count,
        device=singular_values.device,
        dtype=singular_values.dtype,
    )
    values = singular_values

    first_pos, first_val = positions[0], values[0]
    last_pos, last_val = positions[-1], values[-1]

    run = last_pos - first_pos
    rise = last_val - first_val
    chord_length = torch.sqrt(run * run + rise * rise)
    if chord_length == 0:
        return 1

    # Perpendicular distance of every sample from the first-to-last chord.
    offsets = torch.abs(
        rise * positions - run * values + last_pos * first_val - last_val * first_pos
    ) / chord_length

    farthest = torch.argmax(offsets).item()
    return max(1, farthest + 1)
def _find_knee_numpy(singular_values): """ Find the knee (elbow) point in a curve of singular values using NumPy. Parameters ---------- singular_values : numpy.ndarray 1D array of singular values in descending order Returns ------- int Rank at the knee point (1-indexed, minimum 1). """ n = len(singular_values) if n <= 2: return 1 x = np.arange(n, dtype=np.float64) y = singular_values.astype(np.float64) dx = x[-1] - x[0] dy = y[-1] - y[0] line_len = np.sqrt(dx * dx + dy * dy) if line_len == 0: return 1 distances = np.abs(dy * x - dx * y + x[-1] * y[0] - y[-1] * x[0]) / line_len knee_index = int(np.argmax(distances)) return max(1, knee_index + 1) def _images_to_tiles(images, tile_size): """ Reshape a stack of images into non-overlapping square tiles. Pads images to multiples of tile_size if necessary (zero-padding), then reshapes each frame into a grid of tiles. Parameters ---------- images : numpy.ndarray 3D array of shape (num_frames, height, width) tile_size : int Size of each square tile in pixels Returns ------- tuple of (numpy.ndarray, int, int) (tiles, original_height, original_width) where tiles has shape (num_frames, num_tiles, tile_size * tile_size) """ num_frames, height, width = images.shape pad_h = (tile_size - height % tile_size) % tile_size pad_w = (tile_size - width % tile_size) % tile_size if pad_h > 0 or pad_w > 0: images = np.pad(images, ((0, 0), (0, pad_h), (0, pad_w))) padded_h = images.shape[1] padded_w = images.shape[2] tiles_y = padded_h // tile_size tiles_x = padded_w // tile_size # (N, H, W) -> (N, tiles_y, tile_size, tiles_x, tile_size) # -> (N, tiles_y, tiles_x, tile_size, tile_size) # -> (N, num_tiles, tile_pixels) tiles = images.reshape(num_frames, tiles_y, tile_size, tiles_x, tile_size) tiles = tiles.transpose(0, 1, 3, 2, 4) tiles = tiles.reshape(num_frames, tiles_y * tiles_x, tile_size * tile_size) return tiles, height, width def _tiles_to_image(tiles, tile_size, original_height, original_width): """ Reshape tiles back into images, 
cropping to the original size. Parameters ---------- tiles : numpy.ndarray 3D array of shape (num_frames, num_tiles, tile_size * tile_size) tile_size : int Size of each square tile original_height : int Original image height (before padding) original_width : int Original image width (before padding) Returns ------- numpy.ndarray 3D array of shape (num_frames, original_height, original_width) """ num_frames = tiles.shape[0] pad_h = (tile_size - original_height % tile_size) % tile_size pad_w = (tile_size - original_width % tile_size) % tile_size padded_h = original_height + pad_h padded_w = original_width + pad_w tiles_y = padded_h // tile_size tiles_x = padded_w // tile_size # (N, num_tiles, tile_pixels) -> (N, tiles_y, tiles_x, tile_size, tile_size) # -> (N, tiles_y, tile_size, tiles_x, tile_size) # -> (N, padded_h, padded_w) images = tiles.reshape(num_frames, tiles_y, tiles_x, tile_size, tile_size) images = images.transpose(0, 1, 3, 2, 4) images = images.reshape(num_frames, padded_h, padded_w) return images[:, :original_height, :original_width]
def _reference_indices(t, num_frames, window_size, gap_size):
    """Return the frame indices used as background references for frame ``t``."""
    start = max(0, t - window_size)
    stop = min(num_frames, t + window_size + 1)
    return [i for i in range(start, stop) if abs(i - t) > gap_size]


def _project_onto_subspace(ref_data, current, rank):
    """Project ``current`` onto the leading left singular vectors of ``ref_data``."""
    U, S, _ = np.linalg.svd(ref_data, full_matrices=False)
    if rank is None:
        k = _find_knee_numpy(S)
    else:
        # Cannot keep more components than there are reference frames.
        k = min(rank, ref_data.shape[1])
    basis = U[:, :k]
    return basis @ (basis.T @ current)


def subspace_background_removal(images, rank=5, window_size=25, gap_size=3, tile_size=None, callback=None):
    """
    Remove background from imagery using a sliding-window low-rank SVD approach.

    For each frame t, background reference frames are selected from a symmetric
    window around t, excluding frames within a gap of +/- gap_size. A truncated
    SVD of the reference frames captures the low-rank background subspace, and
    the current frame is projected onto it to estimate the background. The
    foreground is the residual (frame minus its projection).

    When tile_size is provided, each frame is divided into non-overlapping
    square tiles that are processed independently, reducing the per-SVD matrix
    size. Without tiling the whole flattened frame is treated as a single tile.

    If a frame has no reference frames at all (for example when the sequence is
    shorter than the gap), its background is set to the frame itself and its
    foreground is left at zero.

    Parameters
    ----------
    images : numpy.ndarray
        3D array of shape (num_frames, height, width) with dtype float32.
    rank : int or None, optional
        Number of singular values to retain for the background subspace, by
        default 5. Higher values capture more complex backgrounds but may
        include target signal. The value is capped at the number of reference
        frames. If None, automatically selected via knee detection per frame
        (and per tile when tiling).
    window_size : int, optional
        Number of reference frames to use on each side of the current frame,
        by default 25. Total potential reference frames = 2 * window_size
        (minus those in the gap).
    gap_size : int, optional
        Number of frames to exclude on each side of the current frame, by
        default 3. Prevents target signal near frame t from leaking into the
        background estimate.
    tile_size : int or None, optional
        Size of square tiles for processing, by default None (no tiling). When
        provided, images are divided into tile_size x tile_size tiles and each
        tile is processed independently. Recommended values: 32, 64, or 128.
    callback : callable, optional
        Called after each frame with (frame_processed, total_frames). Should
        return False to cancel processing.

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray)
        (background, foreground) with the same shape and dtype as input.

    Raises
    ------
    ValueError
        If rank is negative. A negative rank would otherwise be interpreted as
        a from-the-end slice and silently keep the wrong number of components.
    InterruptedError
        If the callback returns False (user cancellation).
    """
    if rank is not None and rank < 0:
        raise ValueError(f"rank must be a non-negative integer or None, got {rank!r}")

    num_frames, height, width = images.shape
    tiled = tile_size is not None

    # Work on a common (num_frames, num_units, unit_pixels) layout so the tiled
    # and non-tiled cases share one processing loop.
    if tiled:
        units, orig_h, orig_w = _images_to_tiles(images, tile_size)
    else:
        units = images.reshape(num_frames, 1, height * width)
    num_units = units.shape[1]

    bg_units = np.zeros_like(units)
    fg_units = np.zeros_like(units)

    for t in range(num_frames):
        refs = _reference_indices(t, num_frames, window_size, gap_size)
        if not refs:
            # No usable references: the whole frame is treated as background.
            bg_units[t] = units[t]
        else:
            for u in range(num_units):
                # Reference data for this unit: (unit_pixels, n_ref)
                ref_data = units[refs, u, :].T
                current = units[t, u, :]
                projection = _project_onto_subspace(ref_data, current, rank)
                bg_units[t, u, :] = projection
                fg_units[t, u, :] = current - projection

        if callback is not None:
            if not callback(t + 1, num_frames):
                raise InterruptedError("Processing cancelled by user")

    if tiled:
        background = _tiles_to_image(bg_units, tile_size, orig_h, orig_w)
        foreground = _tiles_to_image(fg_units, tile_size, orig_h, orig_w)
        return background, foreground

    return (bg_units.reshape(num_frames, height, width),
            fg_units.reshape(num_frames, height, width))