Source code for vista.algorithms.trackers.simple_tracker

"""Simple nearest-neighbor tracker with automatic parameter adaptation"""
import numpy as np
from scipy.optimize import linear_sum_assignment


[docs] class SimpleTrack: """ Simple track using running average for position prediction. This track class maintains a history of detected positions and predicts future positions using linear extrapolation based on recent velocity. Attributes ---------- id : int Unique identifier for this track. positions : list of ndarray List of (x, y) position arrays for each detection/prediction. frames : list of int Frame numbers corresponding to each position. max_search_radius : float Maximum distance to search for associated detections. hits : int Number of times this track was updated with a detection. misses : int Number of consecutive frames without a detection. age : int Total number of frames this track has existed. """
[docs] def __init__(self, detection_pos, frame, track_id, max_search_radius): """ Initialize a new track. Parameters ---------- detection_pos : ndarray Initial (x, y) position of the detection. frame : int Frame number where this track was initiated. track_id : int Unique identifier for this track. max_search_radius : float Maximum distance to search for associated detections. """ self.id = track_id self.positions = [detection_pos.copy()] self.frames = [frame] self.max_search_radius = max_search_radius # Track quality metrics self.hits = 1 self.misses = 0 self.age = 1
[docs] def predict_position(self): """ Predict next position using velocity estimate. Returns ------- ndarray Predicted (x, y) position for the next frame. If fewer than 2 positions exist, returns the last known position. Otherwise, uses linear extrapolation from the last 2-3 positions. Notes ----- The prediction uses up to the last 3 positions to estimate velocity, providing more stable predictions than using only the last 2 positions. """ if len(self.positions) < 2: # No velocity estimate yet, use last position return self.positions[-1] # Use last 3 positions for velocity estimate if available n = min(3, len(self.positions)) recent_positions = np.array(self.positions[-n:]) # Simple linear extrapolation if n >= 2: velocity = recent_positions[-1] - recent_positions[-2] return recent_positions[-1] + velocity else: return recent_positions[-1]
[docs] def distance_to(self, detection_pos): """ Compute Euclidean distance from detection to predicted position. Parameters ---------- detection_pos : ndarray The (x, y) position of a detection to compare against. Returns ------- float Euclidean distance between the detection and this track's predicted position. """ pred = self.predict_position() return np.linalg.norm(detection_pos - pred)
[docs] def update(self, detection_pos, frame): """ Update track with a new detection. Parameters ---------- detection_pos : ndarray The (x, y) position of the new detection. frame : int Frame number of the new detection. Notes ----- This method appends the new position to the track history, increments the hit count, resets the consecutive miss count, and increments age. """ self.positions.append(detection_pos.copy()) self.frames.append(frame) self.hits += 1 self.misses = 0 self.age += 1
[docs] def mark_missed(self, frame): """ Mark a frame as missed and add predicted position. Parameters ---------- frame : int Frame number that was missed (no detection associated). Notes ----- When a track is not associated with any detection in a frame, this method adds the predicted position to maintain track continuity. It increments both the consecutive miss count and the track age. """ pred_pos = self.predict_position() self.positions.append(pred_pos) self.frames.append(frame) self.misses += 1 self.age += 1
[docs] def quality_score(self): """ Compute track quality score (higher is better). Returns ------- float Quality score between 0 and 1, representing track reliability. Calculated as detection rate multiplied by a recency penalty factor that heavily penalizes recent misses. Notes ----- The quality score combines two factors: - Detection rate: ratio of hits to total age - Recency penalty: exponential decay based on consecutive misses This score helps identify high-quality tracks for retention and low-quality tracks for deletion. """ if self.age == 0: return 0 # Detection rate weighted by age detection_rate = self.hits / self.age # Penalize recent misses heavily recency_penalty = np.exp(-self.misses / 3.0) return detection_rate * recency_penalty
[docs] def run_simple_tracker(detectors, config): """ Run simple nearest-neighbor tracker with adaptive parameters. This tracker automatically adapts to the data and requires minimal configuration. It uses Hungarian algorithm for detection-to-track association and automatically computes search radius and maximum track age if not provided. Parameters ---------- detectors : list of Detector List of Detector objects to use as input. Each detector should have frames, columns, and rows attributes containing detection data. config : dict Dictionary containing tracker configuration with the following keys: - tracker_name : str, optional Name for the resulting tracker. Default is 'Simple Tracker'. - max_search_radius : float, optional Maximum distance to search for associations. If not provided, automatically computed from detection nearest-neighbor statistics. - min_track_length : int, optional Minimum number of detections for a valid track. Default is 5. - max_age : int, optional Maximum frames a track can go without detection before deletion. If not provided, automatically computed based on frame gaps. Returns ------- list of dict List of track data dictionaries, each containing: - 'frames' : ndarray Array of frame numbers where track appears. - 'rows' : ndarray Array of row (y) coordinates for each frame. - 'columns' : ndarray Array of column (x) coordinates for each frame. Notes ----- The tracker performs the following steps: 1. Collects all detections organized by frame 2. Auto-computes max_search_radius using 90th percentile of nearest-neighbor distances if not provided 3. Auto-computes max_age based on average frame gaps if not provided 4. Associates detections to tracks using Hungarian algorithm 5. Creates new tracks from unassociated detections 6. Marks missed detections for tracks without associations 7. Deletes tracks that exceed max_age or have poor quality scores 8. Returns tracks that meet minimum length requirement Examples -------- >>> config = { ... 'tracker_name': 'My Tracker', ... 'min_track_length': 5, ... 'max_search_radius': 10.0 ... } >>> tracks = run_simple_tracker(detectors, config) >>> print(f"Found {len(tracks)} tracks") """ # Extract configuration with smart defaults tracker_name = config.get('tracker_name', 'Simple Tracker') min_track_length = config.get('min_track_length', 5) # Collect all detections by frame detections_by_frame = {} all_detections_list = [] for detector in detectors: for i, frame in enumerate(detector.frames): pos = np.array([detector.columns[i], detector.rows[i]]) if frame not in detections_by_frame: detections_by_frame[frame] = [] detections_by_frame[frame].append(pos) all_detections_list.append(pos) # Auto-compute search radius if not provided if 'max_search_radius' in config and config['max_search_radius'] is not None: max_search_radius = config['max_search_radius'] else: # Estimate based on nearest-neighbor distances in detections if len(all_detections_list) > 10: all_dets = np.array(all_detections_list) # Compute pairwise distances for a sample sample_size = min(500, len(all_dets)) sample_indices = np.random.choice(len(all_dets), sample_size, replace=False) sample = all_dets[sample_indices] # Find 2nd nearest neighbor for each (1st is itself) dists = [] for det in sample[:100]: # Use subset for speed distances = np.linalg.norm(sample - det, axis=1) distances.sort() if len(distances) > 1: dists.append(distances[1]) # Use 90th percentile of nearest neighbor distances * 3 max_search_radius = np.percentile(dists, 90) * 3 if dists else 50.0 max_search_radius = max(10.0, min(100.0, max_search_radius)) # Clamp to reasonable range else: max_search_radius = 30.0 # Auto-compute max age if not provided if 'max_age' in config and config['max_age'] is not None: max_age = config['max_age'] else: # Estimate based on detection density frames = sorted(detections_by_frame.keys()) if len(frames) > 1: frame_gaps = np.diff(frames) avg_gap = np.mean(frame_gaps) # Allow tracks to survive ~3x the average frame gap max_age = int(max(3, min(10, avg_gap * 3))) else: max_age = 5 # Track management active_tracks = [] finished_tracks = [] # Tracks that were deleted but may still be valid next_track_id = 1 frames = sorted(detections_by_frame.keys()) # Process each frame for frame in frames: detections = detections_by_frame[frame] if len(active_tracks) > 0 and len(detections) > 0: # Build cost matrix (distances) cost_matrix = np.full((len(active_tracks), len(detections)), max_search_radius * 2) for i, track in enumerate(active_tracks): for j, detection in enumerate(detections): dist = track.distance_to(detection) if dist < max_search_radius: cost_matrix[i, j] = dist # Solve assignment track_indices, det_indices = linear_sum_assignment(cost_matrix) # Track assignments assigned_detections = set() assigned_tracks = set() for track_idx, det_idx in zip(track_indices, det_indices): if cost_matrix[track_idx, det_idx] < max_search_radius: active_tracks[track_idx].update(detections[det_idx], frame) assigned_detections.add(det_idx) assigned_tracks.add(track_idx) # Mark missed tracks for i, track in enumerate(active_tracks): if i not in assigned_tracks: track.mark_missed(frame) # Create new tracks from unassigned detections for j, detection in enumerate(detections): if j not in assigned_detections: new_track = SimpleTrack(detection, frame, next_track_id, max_search_radius) active_tracks.append(new_track) next_track_id += 1 elif len(detections) > 0: # No tracks yet, initialize from detections for detection in detections: new_track = SimpleTrack(detection, frame, next_track_id, max_search_radius) active_tracks.append(new_track) next_track_id += 1 else: # No detections, just mark all tracks as missed for track in active_tracks: track.mark_missed(frame) # Delete old/poor tracks tracks_to_remove = [] for track in active_tracks: # Delete if too many consecutive misses if track.misses > max_age: tracks_to_remove.append(track) # Delete if quality is too low and track is old enough elif track.age > 10 and track.quality_score() < 0.3: tracks_to_remove.append(track) for track in tracks_to_remove: active_tracks.remove(track) # Save deleted tracks that might still be valid if track.hits >= min_track_length: finished_tracks.append(track) # Convert to track data, filter by minimum length # Combine both active tracks and finished tracks all_valid_tracks = [] for track in active_tracks: if track.hits >= min_track_length: all_valid_tracks.append(track) all_valid_tracks.extend(finished_tracks) track_data_list = [] for track in all_valid_tracks: positions = np.array(track.positions) frames_array = np.array(track.frames, dtype=np.int_) track_data = { 'frames': frames_array, 'rows': positions[:, 1], # y 'columns': positions[:, 0], # x } track_data_list.append(track_data) return track_data_list