Source code for vista.algorithms.tracks.interpolation
"""
Track interpolation algorithm for filling missing frames in trajectories.
This module provides the TrackInterpolation class which fills gaps in track data
by interpolating missing frames between existing track points.
"""
import numpy as np
from numpy.typing import NDArray
from scipy.interpolate import interp1d
from vista.tracks.track import Track
[docs]
class TrackInterpolation:
    """
    Interpolates missing frames in a track trajectory.

    Takes a Track object that may have gaps in frame coverage and returns a new
    Track with interpolated positions for all missing frames between the first
    and last tracked frames.

    Parameters
    ----------
    track : Track
        Input track that may have missing frames
    method : str, optional
        Interpolation method for scipy.interp1d. Options include:

        - 'linear': Linear interpolation (default)
        - 'nearest': Nearest-neighbor interpolation
        - 'zero': Zero-order spline (piecewise constant)
        - 'slinear': First-order spline
        - 'quadratic': Second-order spline
        - 'cubic': Third-order spline

        By default 'linear'. If scipy cannot build an interpolant for the
        chosen method (e.g. an unsupported name, or too few points for the
        requested spline order), calling the instance raises ValueError.

    Methods
    -------
    __call__()
        Execute the interpolation and return results

    Examples
    --------
    >>> interpolator = TrackInterpolation(track, method='linear')
    >>> results = interpolator()
    >>> interpolated_track = results['interpolated_track']
    """

    def __init__(self, track: Track, method: str = 'linear'):
        """
        Initialize the track interpolation algorithm.

        Parameters
        ----------
        track : Track
            Input track with potentially missing frames
        method : str, optional
            Interpolation method for scipy.interp1d, by default 'linear'
        """
        self.track = track
        self.method = method

    def _build_interpolator(self, frames: NDArray, values: NDArray) -> interp1d:
        """
        Build a scipy interpolant mapping frame number to ``values``.

        Shared by the row and column interpolation so both coordinates are
        always interpolated with identical settings.

        Parameters
        ----------
        frames : NDArray
            Observed frame numbers (need not be sorted; scipy sorts them
            because ``assume_sorted=False``).
        values : NDArray
            Observed coordinate values, one per entry in ``frames``.

        Returns
        -------
        interp1d
            Callable evaluating the interpolant at arbitrary frame numbers.
        """
        return interp1d(
            frames,
            values,
            kind=self.method,
            assume_sorted=False,
            # Evaluation points never leave [min_frame, max_frame]; presumably
            # 'extrapolate' is kept as a guard against bounds errors at the
            # range ends rather than to actually extrapolate.
            fill_value='extrapolate',
        )

    def __call__(self) -> dict:
        """
        Execute interpolation on the track.

        Positions at every frame in the output, including frames that were
        already present in the input track, are evaluated from the fitted
        interpolant.

        Returns
        -------
        dict
            Dictionary containing:

            - 'interpolated_track': Track object with all frames filled
            - 'original_frames': Copy of the input track's frame array
            - 'interpolated_frames': Array of frame numbers that were interpolated
            - 'n_interpolated': Number of frames that were interpolated

        Raises
        ------
        ValueError
            If the track has fewer than 2 points (cannot interpolate), or if
            scipy fails to build the interpolant. In the latter case the
            original scipy exception is attached as ``__cause__``.
        """
        # Validate input
        if len(self.track.frames) < 2:
            raise ValueError("Track must have at least 2 points to interpolate")

        # Get existing track data
        existing_frames = self.track.frames
        existing_rows = self.track.rows
        existing_columns = self.track.columns

        # Full, step-1 range of frames from the first to the last observation
        min_frame = existing_frames.min()
        max_frame = existing_frames.max()
        all_frames = np.arange(min_frame, max_frame + 1)

        # Find which frames are missing
        missing_mask = ~np.isin(all_frames, existing_frames)
        missing_frames = all_frames[missing_mask]

        # If no missing frames, return a copy of the original track
        if len(missing_frames) == 0:
            return {
                'interpolated_track': self.track.copy(),
                'original_frames': existing_frames.copy(),
                'interpolated_frames': np.array([], dtype=np.int_),
                'n_interpolated': 0
            }

        # Build one interpolant per coordinate
        try:
            row_interp = self._build_interpolator(existing_frames, existing_rows)
            col_interp = self._build_interpolator(existing_frames, existing_columns)
        except Exception as e:
            # Surface every construction failure as the documented ValueError,
            # but chain scipy's exception so the real cause is not lost.
            raise ValueError(f"Interpolation failed: {e}") from e

        # Interpolate positions for all frames
        all_rows = row_interp(all_frames)
        all_columns = col_interp(all_frames)

        # Create new interpolated track, carrying over the display metadata
        interpolated_track = Track(
            name=self.track.name,
            frames=all_frames,
            rows=all_rows,
            columns=all_columns,
            sensor=self.track.sensor,
            color=self.track.color,
            marker=self.track.marker,
            line_width=self.track.line_width,
            marker_size=self.track.marker_size,
            visible=self.track.visible,
            tail_length=self.track.tail_length,
            complete=self.track.complete,
            show_line=self.track.show_line,
            line_style=self.track.line_style,
            labels=self.track.labels.copy()
        )

        # Return results
        return {
            'interpolated_track': interpolated_track,
            'original_frames': existing_frames.copy(),
            'interpolated_frames': missing_frames,
            'n_interpolated': len(missing_frames)
        }