Smooth pose tracks

Smooth pose tracks using the median and Savitzky-Golay filters.

Imports

from matplotlib import pyplot as plt
from scipy.signal import welch

from movement import sample_data
from movement.filtering import (
    interpolate_over_time,
    median_filter,
    savgol_filter,
)

Load a sample dataset

Let’s load a sample dataset and print it to inspect its contents. Note that if you are running this notebook interactively, you can simply type the variable name (here ds_wasp) in a cell to get an interactive display of the dataset’s contents.

ds_wasp = sample_data.fetch_dataset("DLC_single-wasp.predictions.h5")
print(ds_wasp)
<xarray.Dataset> Size: 61kB
Dimensions:      (time: 1085, space: 2, keypoints: 2, individuals: 1)
Coordinates:
  * time         (time) float64 9kB 0.0 0.025 0.05 0.075 ... 27.05 27.07 27.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U7 56B 'head' 'stinger'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float64 35kB 1.086e+03...
    confidence   (time, keypoints, individuals) float64 17kB 0.05305 ... 0.0
Attributes:
    fps:              40.0
    time_unit:        seconds
    source_software:  DeepLabCut
    source_file:      /home/runner/.movement/data/poses/DLC_single-wasp.predi...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-wasp_frame-10...
    video_path:       None

We see that the dataset contains the 2D pose tracks and confidence scores for a single wasp, generated with DeepLabCut. The wasp is tracked at two keypoints, “head” and “stinger”, in a video that was recorded at 40 fps and lasts for approximately 27 seconds.

Define a plotting function

Let’s define a plotting function to help us visualise the effects of smoothing both in the time and frequency domains. The function takes as inputs two datasets containing raw and smooth data respectively, and plots the position time series and power spectral density (PSD) for a given individual and keypoint. The function also allows you to specify the spatial coordinate (x or y) and a time range to focus on.

def plot_raw_and_smooth_timeseries_and_psd(
    ds_raw,
    ds_smooth,
    individual="individual_0",
    keypoint="stinger",
    space="x",
    time_range=None,
):
    # If no time range is specified, plot the entire time series
    if time_range is None:
        time_range = slice(0, ds_raw.time[-1])

    selection = {
        "time": time_range,
        "individuals": individual,
        "keypoints": keypoint,
        "space": space,
    }

    fig, ax = plt.subplots(2, 1, figsize=(10, 6))

    for ds, color, label in zip(
        [ds_raw, ds_smooth], ["k", "r"], ["raw", "smooth"], strict=False
    ):
        # plot position time series
        pos = ds.position.sel(**selection)
        ax[0].plot(
            pos.time,
            pos,
            color=color,
            lw=2,
            alpha=0.7,
            label=f"{label} {space}",
        )

        # interpolate data to remove NaNs in the PSD calculation
        pos_interp = interpolate_over_time(pos, print_report=False)

        # compute and plot the PSD
        freq, psd = welch(pos_interp, fs=ds.fps, nperseg=256)
        ax[1].semilogy(
            freq,
            psd,
            color=color,
            lw=2,
            alpha=0.7,
            label=f"{label} {space}",
        )

    ax[0].set_ylabel(f"{space} position (px)")
    ax[0].set_xlabel("Time (s)")
    ax[0].set_title("Time Domain")
    ax[0].legend()

    ax[1].set_ylabel("PSD (px$^2$/Hz)")
    ax[1].set_xlabel("Frequency (Hz)")
    ax[1].set_title("Frequency Domain")
    ax[1].legend()

    plt.tight_layout()
    fig.show()

Smoothing with a median filter

Using the movement.filtering.median_filter() function on the position data variable, we can apply a rolling-window median filter over a 0.1-second window (4 frames) to the wasp dataset. Because the window parameter is specified as a number of observations (frames), we obtain it by multiplying the desired duration by the frame rate of the video. We will also create a copy of the dataset to avoid modifying the original data.

window = int(0.1 * ds_wasp.fps)
ds_wasp_smooth = ds_wasp.copy()
ds_wasp_smooth.update({"position": median_filter(ds_wasp.position, window)})
Missing points (marked as NaN) in input
        Individual: individual_0
                head: 0/1085 (0.0%)
                stinger: 0/1085 (0.0%)

Missing points (marked as NaN) in output
        Individual: individual_0
                head: 0/1085 (0.0%)
                stinger: 0/1085 (0.0%)
<xarray.Dataset> Size: 61kB
Dimensions:      (time: 1085, space: 2, keypoints: 2, individuals: 1)
Coordinates:
  * time         (time) float64 9kB 0.0 0.025 0.05 0.075 ... 27.05 27.07 27.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U7 56B 'head' 'stinger'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float64 35kB 1.086e+03...
    confidence   (time, keypoints, individuals) float64 17kB 0.05305 ... 0.0
Attributes:
    fps:              40.0
    time_unit:        seconds
    source_software:  DeepLabCut
    source_file:      /home/runner/.movement/data/poses/DLC_single-wasp.predi...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-wasp_frame-10...
    video_path:       None
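
The int() conversion above truncates towards zero. Because durations such as 0.1 s are not exactly representable in binary floating point, a product that should be a whole number of frames can land just below it and silently lose a frame. A minimal sketch of a safer conversion, using a hypothetical helper `window_from_seconds` that is not part of movement, rounds to the nearest frame instead:

```python
def window_from_seconds(seconds: float, fps: float) -> int:
    """Convert a duration in seconds to a window size in frames.

    Hypothetical helper (not part of movement): rounds to the nearest
    whole frame instead of truncating, and rejects windows that would
    span less than one frame.
    """
    n_frames = round(seconds * fps)
    if n_frames < 1:
        raise ValueError(
            f"A {seconds} s window at {fps} fps spans less than one frame."
        )
    return n_frames


# Window sizes used in this example
print(window_from_seconds(0.1, 40))  # wasp, 0.1 s -> 4 frames
print(window_from_seconds(0.1, 30))  # mouse, 0.1 s -> 3 frames
print(window_from_seconds(2, 30))    # mouse, 2 s -> 60 frames

# A case where truncation loses a frame: 0.29 * 100 is stored as
# 28.999999999999996, so int() gives 28 while round() gives 29.
print(int(0.29 * 100), window_from_seconds(0.29, 100))
```

For the durations and frame rates used in this example, int() and round() give the same results; the 100 fps case is purely illustrative.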


We see from the printed report that the dataset has no missing values either before or after smoothing. Let’s visualise the effects of the median filter in the time and frequency domains.

plot_raw_and_smooth_timeseries_and_psd(
    ds_wasp, ds_wasp_smooth, keypoint="stinger"
)
[Figure: raw (black) and smoothed (red) traces in two panels, “Time Domain” and “Frequency Domain”]

We see that the median filter has removed the “spikes” present around the 14-second mark in the raw data. However, it has not dealt with the big shift occurring during the final second. In the frequency domain, we can see that the filter has reduced the power in the high-frequency components without affecting the low-frequency components.
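
To see why the median filter lowers the high-frequency end of the PSD while leaving the low frequencies largely intact, the sketch below uses a synthetic signal rather than the wasp data: a slow 0.5 Hz oscillation plus random jitter, sampled at 40 fps. The centred rolling median with reflect-padded edges is a NumPy stand-in written for illustration, not movement's own implementation, and band power is measured with a plain FFT periodogram rather than Welch's method.

```python
import numpy as np


def rolling_median(x, window):
    """Centred rolling median over an odd window, edges reflect-padded."""
    half = window // 2
    padded = np.pad(x, half, mode="reflect")
    return np.array([np.median(padded[i : i + window]) for i in range(len(x))])


def band_power(x, fps, f_lo, f_hi):
    """Summed FFT periodogram of x between f_lo and f_hi (in Hz)."""
    x = x - x.mean()
    freqs = np.fft.rfftfreq(x.size, d=1 / fps)
    power = np.abs(np.fft.rfft(x)) ** 2
    in_band = (freqs >= f_lo) & (freqs <= f_hi)
    return power[in_band].sum()


# Synthetic stand-in for a keypoint trace: slow 0.5 Hz movement plus
# frame-to-frame jitter, sampled at 40 fps for 27 s (like the wasp video)
fps = 40.0
t = np.arange(0, 27, 1 / fps)
rng = np.random.default_rng(seed=0)
raw = 10 * np.sin(2 * np.pi * 0.5 * t) + rng.normal(scale=1.0, size=t.size)
smooth = rolling_median(raw, window=5)

low_ratio = band_power(smooth, fps, 0, 1) / band_power(raw, fps, 0, 1)
high_ratio = band_power(smooth, fps, 10, 20) / band_power(raw, fps, 10, 20)
print(f"0-1 Hz power, smoothed/raw:   {low_ratio:.2f}")
print(f"10-20 Hz power, smoothed/raw: {high_ratio:.2f}")
```

The low-frequency ratio stays close to 1, while the high-frequency ratio drops well below it, which is the same pattern as the separation between the red and black PSD curves.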

This illustrates what the median filter is good at: removing brief “spikes” (e.g. a keypoint abruptly jumping to a different location for a frame or two) and high-frequency “jitter” (often present due to pose estimation working on a per-frame basis).
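
The spike-removal behaviour can be reproduced on a short made-up trace. The sketch below again uses NumPy stand-ins (centred odd windows, reflect-padded edges), not movement's implementation, and contrasts the median with a rolling mean of the same width: the median discards a one-frame jump entirely, whereas the mean spreads it over neighbouring frames.

```python
import numpy as np


def rolling_median(x, window):
    """Centred rolling median over an odd window, edges reflect-padded."""
    half = window // 2
    padded = np.pad(x, half, mode="reflect")
    return np.array([np.median(padded[i : i + window]) for i in range(len(x))])


def rolling_mean(x, window):
    """Centred rolling mean with the same edge handling, for comparison."""
    half = window // 2
    padded = np.pad(x, half, mode="reflect")
    return np.array([padded[i : i + window].mean() for i in range(len(x))])


# A keypoint resting at x = 100 px that jumps to 300 px for a single frame
x = np.array([100.0, 100, 100, 100, 300, 100, 100, 100, 100])

print(rolling_median(x, 3))  # the one-frame spike is discarded entirely
print(rolling_mean(x, 3))    # the spike is smeared over three frames
```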

Choosing parameters for the median filter

We can control the behaviour of the median filter via two parameters: window and min_periods. To better understand the effect of these parameters, let’s use a dataset that contains missing values.

ds_mouse = sample_data.fetch_dataset("SLEAP_single-mouse_EPM.analysis.h5")
print(ds_mouse)
<xarray.Dataset> Size: 1MB
Dimensions:      (time: 18485, space: 2, keypoints: 6, individuals: 1)
Coordinates:
  * time         (time) float64 148kB 0.0 0.03333 0.06667 ... 616.1 616.1 616.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U9 216B 'snout' 'left_ear' ... 'tail_end'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float32 887kB nan ... ...
    confidence   (time, keypoints, individuals) float32 444kB nan nan ... 0.7607
Attributes:
    fps:              30.0
    time_unit:        seconds
    source_software:  SLEAP
    source_file:      /home/runner/.movement/data/poses/SLEAP_single-mouse_EP...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-mouse_EPM_fra...
    video_path:       None

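The dataset contains a single mouse with six keypoints tracked in 2D space. The video was recorded at 30 fps and lasts for ~616 seconds. We can see that there are some missing values, indicated as “nan” in the printed dataset. Let’s apply the median filter over a 0.1-second window (3 frames) to the dataset, again working on a copy.

window = int(0.1 * ds_mouse.fps)
ds_mouse_smooth = ds_mouse.copy()
ds_mouse_smooth.update({"position": median_filter(ds_mouse.position, window)})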
Missing points (marked as NaN) in input
        Individual: individual_0
                snout: 4494/18485 (24.3%)
                left_ear: 513/18485 (2.8%)
                right_ear: 533/18485 (2.9%)
                centre: 490/18485 (2.7%)
                tail_base: 704/18485 (3.8%)
                tail_end: 2496/18485 (13.5%)

Missing points (marked as NaN) in output
        Individual: individual_0
                snout: 5106/18485 (27.6%)
                left_ear: 678/18485 (3.7%)
                right_ear: 695/18485 (3.8%)
                centre: 640/18485 (3.5%)
                tail_base: 913/18485 (4.9%)
                tail_end: 3103/18485 (16.8%)
<xarray.Dataset> Size: 1MB
Dimensions:      (time: 18485, space: 2, keypoints: 6, individuals: 1)
Coordinates:
  * time         (time) float64 148kB 0.0 0.03333 0.06667 ... 616.1 616.1 616.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U9 216B 'snout' 'left_ear' ... 'tail_end'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float32 887kB nan ... ...
    confidence   (time, keypoints, individuals) float32 444kB nan nan ... 0.7607
Attributes:
    fps:              30.0
    time_unit:        seconds
    source_software:  SLEAP
    source_file:      /home/runner/.movement/data/poses/SLEAP_single-mouse_EP...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-mouse_EPM_fra...
    video_path:       None


The report informs us that the raw data contains NaN values, most of which occur at the snout and tail_end keypoints. After filtering, the number of NaNs has increased. This is because the default behaviour of the median filter is to propagate NaN values, i.e. if any value in the rolling window is NaN, the output will also be NaN.
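
The growth in NaNs follows from how the window is applied. With NaN propagation, one missing frame invalidates every window that contains it, so each run of missing frames grows by window - 1 frames (half a window on each side, for a centred window). The NumPy sketch below illustrates this on a made-up trace; the centred window and reflect-padded edges are assumptions made for the illustration, not a description of movement's internals.

```python
import numpy as np


def rolling_median(x, window, min_periods=None):
    """Centred rolling median over an odd window, edges reflect-padded.

    An output frame is NaN unless its window contains at least
    `min_periods` non-NaN values; by default the whole window must be
    valid, so any NaN in the window propagates to the output.
    """
    if min_periods is None:
        min_periods = window
    half = window // 2
    padded = np.pad(x, half, mode="reflect")
    out = np.full(len(x), np.nan)
    for i in range(len(x)):
        segment = padded[i : i + window]
        valid = segment[~np.isnan(segment)]
        if valid.size >= min_periods:
            out[i] = np.median(valid)
    return out


# Made-up trace with one missing frame and a run of three missing frames
x = np.arange(20, dtype=float)
x[5] = np.nan
x[12:15] = np.nan

smooth = rolling_median(x, window=3)
print("NaNs in input: ", np.isnan(x).sum())       # 4
print("NaNs in output:", np.isnan(smooth).sum())  # 8: each run grows by 2
```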

To modify this behaviour, you can set the value of the min_periods parameter to an integer value. This parameter determines the minimum number of non-NaN values required in the window for the output to be non-NaN. For example, setting min_periods=2 means that two non-NaN values in the window are sufficient for the median to be calculated. Let’s try this.

ds_mouse_smooth.update(
    {"position": median_filter(ds_mouse.position, window, min_periods=2)}
)
Missing points (marked as NaN) in input
        Individual: individual_0
                snout: 4494/18485 (24.3%)
                left_ear: 513/18485 (2.8%)
                right_ear: 533/18485 (2.9%)
                centre: 490/18485 (2.7%)
                tail_base: 704/18485 (3.8%)
                tail_end: 2496/18485 (13.5%)

Missing points (marked as NaN) in output
        Individual: individual_0
                snout: 4455/18485 (24.1%)
                left_ear: 487/18485 (2.6%)
                right_ear: 507/18485 (2.7%)
                centre: 465/18485 (2.5%)
                tail_base: 673/18485 (3.6%)
                tail_end: 2428/18485 (13.1%)
<xarray.Dataset> Size: 1MB
Dimensions:      (time: 18485, space: 2, keypoints: 6, individuals: 1)
Coordinates:
  * time         (time) float64 148kB 0.0 0.03333 0.06667 ... 616.1 616.1 616.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U9 216B 'snout' 'left_ear' ... 'tail_end'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float32 887kB nan ... ...
    confidence   (time, keypoints, individuals) float32 444kB nan nan ... 0.7607
Attributes:
    fps:              30.0
    time_unit:        seconds
    source_software:  SLEAP
    source_file:      /home/runner/.movement/data/poses/SLEAP_single-mouse_EP...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-mouse_EPM_fra...
    video_path:       None


We see that this time the number of NaN values has decreased across all keypoints. Let’s visualise the effects of the median filter in the time and frequency domains. Here we focus on the first 80 seconds for the snout keypoint. You can adjust the keypoint and time_range arguments to explore other parts of the data.

plot_raw_and_smooth_timeseries_and_psd(
    ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80)
)
[Figure: raw (black) and smoothed (red) traces in two panels, “Time Domain” and “Frequency Domain”]

The smoothing once again reduces the power of high-frequency components, but the resulting time series stays quite close to the raw data.

What happens if we increase the window to 2 seconds (60 frames)?

window = int(2 * ds_mouse.fps)
ds_mouse_smooth.update(
    {"position": median_filter(ds_mouse.position, window, min_periods=2)}
)
Missing points (marked as NaN) in input
        Individual: individual_0
                snout: 4494/18485 (24.3%)
                left_ear: 513/18485 (2.8%)
                right_ear: 533/18485 (2.9%)
                centre: 490/18485 (2.7%)
                tail_base: 704/18485 (3.8%)
                tail_end: 2496/18485 (13.5%)

Missing points (marked as NaN) in output
        Individual: individual_0
                snout: 795/18485 (4.3%)
                left_ear: 80/18485 (0.4%)
                right_ear: 80/18485 (0.4%)
                centre: 80/18485 (0.4%)
                tail_base: 80/18485 (0.4%)
                tail_end: 239/18485 (1.3%)
<xarray.Dataset> Size: 1MB
Dimensions:      (time: 18485, space: 2, keypoints: 6, individuals: 1)
Coordinates:
  * time         (time) float64 148kB 0.0 0.03333 0.06667 ... 616.1 616.1 616.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U9 216B 'snout' 'left_ear' ... 'tail_end'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float32 887kB nan ... ...
    confidence   (time, keypoints, individuals) float32 444kB nan nan ... 0.7607
Attributes:
    fps:              30.0
    time_unit:        seconds
    source_software:  SLEAP
    source_file:      /home/runner/.movement/data/poses/SLEAP_single-mouse_EP...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-mouse_EPM_fra...
    video_path:       None


The number of NaN values has decreased even further. That’s because the chance of finding at least 2 valid values within a 2-second window (i.e. 60 frames) is quite high. Let’s plot the results for the same keypoint and time range as before.

plot_raw_and_smooth_timeseries_and_psd(
    ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80)
)
[Figure: raw (black) and smoothed (red) traces in two panels, “Time Domain” and “Frequency Domain”]

We see that the filtered time series is much smoother and it has even “bridged” over some small gaps. That said, it often deviates from the raw data, in ways that may not be desirable, depending on the application. Here, our choice of window may be too large. In general, you should choose a window that is small enough to preserve the original data structure, but large enough to remove “spikes” and high-frequency noise. Always inspect the results to ensure that the filter is not removing important features.
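
To see how window and min_periods interact, the sketch below reuses the same kind of NumPy stand-in (centred windows, reflect-padded edges; not movement's implementation) on a made-up trace containing a 1-frame gap and a 10-frame gap. Because the sketch assumes odd windows, 61 frames stands in for the 60-frame window used above.

```python
import numpy as np


def rolling_median(x, window, min_periods=None):
    """Centred rolling median over an odd window, edges reflect-padded.

    An output frame is NaN unless its window contains at least
    `min_periods` non-NaN values (default: the whole window).
    """
    if min_periods is None:
        min_periods = window
    half = window // 2
    padded = np.pad(x, half, mode="reflect")
    out = np.full(len(x), np.nan)
    for i in range(len(x)):
        segment = padded[i : i + window]
        valid = segment[~np.isnan(segment)]
        if valid.size >= min_periods:
            out[i] = np.median(valid)
    return out


# Made-up trace: a slow ramp with a 1-frame gap and a 10-frame gap
x = np.linspace(0, 50, 200)
x[50] = np.nan
x[100:110] = np.nan

print("NaNs in input:", np.isnan(x).sum())  # 11
for window, min_periods in [(3, None), (3, 2), (61, 2)]:
    n_nan = np.isnan(rolling_median(x, window, min_periods)).sum()
    print(f"window={window}, min_periods={min_periods}: {n_nan} NaNs")
# -> 15, 10 and 0 NaNs respectively
```

With a 3-frame window, min_periods=2 bridges the isolated missing frame but not the longer gap. With a 61-frame window, every output frame has at least two valid frames within reach, so even the 10-frame gap is filled with medians of frames up to a second away, which is why such filled values may not reflect the true position.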

Smoothing with a Savitzky-Golay filter

Here we apply the movement.filtering.savgol_filter() function (a wrapper around scipy.signal.savgol_filter()) to the position data variable. The Savitzky-Golay filter is a polynomial smoothing filter that is applied to time series data on a rolling-window basis. Within each window of length window, a polynomial of degree polyorder is fitted to the data by least squares, and the value of that polynomial at the midpoint of the window is used as the output value.
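
The procedure described above can be written out in a few lines of NumPy: for each frame, fit a polynomial of degree polyorder by least squares to the surrounding window and keep the fitted value at the centre frame. This sketch assumes an odd window and leaves frames near the ends as NaN. scipy.signal.savgol_filter() obtains the same interior values far more efficiently through a fixed convolution kernel and handles the edges in its own way, so treat the sketch as an illustration of the idea only.

```python
import numpy as np


def savgol_sketch(x, window, polyorder=2):
    """Savitzky-Golay smoothing written as explicit local polynomial fits.

    Frames closer than window // 2 to either end are left as NaN.
    """
    if window % 2 == 0 or window <= polyorder:
        raise ValueError("This sketch needs an odd window larger than polyorder.")
    half = window // 2
    offsets = np.arange(-half, half + 1)  # positions relative to the centre frame
    out = np.full(len(x), np.nan)
    for i in range(half, len(x) - half):
        # Least-squares fit of a degree-`polyorder` polynomial to the window
        coeffs = np.polyfit(offsets, x[i - half : i + half + 1], polyorder)
        out[i] = np.polyval(coeffs, 0)  # value of the fit at the centre frame
    return out


# 1. A unit impulse recovers the classic 5-point quadratic kernel
impulse = np.zeros(9)
impulse[4] = 1.0
print(savgol_sketch(impulse, 5)[2:7] * 35)  # approximately [-3, 12, 17, 12, -3]

# 2. A noisy quadratic: the local fits remove much of the noise
t = np.linspace(0, 4, 200)
clean = 3 * t**2 - 5 * t + 1
rng = np.random.default_rng(seed=0)
noisy = clean + rng.normal(scale=0.5, size=t.size)
smoothed = savgol_sketch(noisy, 11)

inner = slice(5, -5)  # skip the NaN edge frames
print("residual std, raw:     ", np.std(noisy[inner] - clean[inner]))
print("residual std, smoothed:", np.std(smoothed[inner] - clean[inner]))
```

Because a quadratic is fitted exactly when polyorder=2, the smoothed curve follows the underlying signal without bias while much of the added noise is removed.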

Let’s try it on the mouse dataset, this time using a 0.2-second window (i.e. 6 frames) and the default polyorder=2 for smoothing. As before, we first compute the corresponding number of observations to be used as the window size.

window = int(0.2 * ds_mouse.fps)
ds_mouse_smooth.update({"position": savgol_filter(ds_mouse.position, window)})
Missing points (marked as NaN) in input
        Individual: individual_0
                snout: 4494/18485 (24.3%)
                left_ear: 513/18485 (2.8%)
                right_ear: 533/18485 (2.9%)
                centre: 490/18485 (2.7%)
                tail_base: 704/18485 (3.8%)
                tail_end: 2496/18485 (13.5%)

Missing points (marked as NaN) in output
        Individual: individual_0
                snout: 5810/18485 (31.4%)
                left_ear: 895/18485 (4.8%)
                right_ear: 905/18485 (4.9%)
                centre: 839/18485 (4.5%)
                tail_base: 1186/18485 (6.4%)
                tail_end: 3801/18485 (20.6%)
<xarray.Dataset> Size: 1MB
Dimensions:      (time: 18485, space: 2, keypoints: 6, individuals: 1)
Coordinates:
  * time         (time) float64 148kB 0.0 0.03333 0.06667 ... 616.1 616.1 616.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U9 216B 'snout' 'left_ear' ... 'tail_end'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float32 887kB nan ... ...
    confidence   (time, keypoints, individuals) float32 444kB nan nan ... 0.7607
Attributes:
    fps:              30.0
    time_unit:        seconds
    source_software:  SLEAP
    source_file:      /home/runner/.movement/data/poses/SLEAP_single-mouse_EP...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-mouse_EPM_fra...
    video_path:       None


We see that the number of NaN values has increased after filtering. This is for the same reason as with the median filter (in its default mode), i.e. if there is at least one NaN value in the window, the output will be NaN. Unlike the median filter, the Savitzky-Golay filter does not provide a min_periods parameter to control this behaviour. Let’s visualise the effects in the time and frequency domains.

plot_raw_and_smooth_timeseries_and_psd(
    ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80)
)
[Figure: raw (black) and smoothed (red) traces in two panels, “Time Domain” and “Frequency Domain”]

Once again, the power of high-frequency components has been reduced, but more missing values have been introduced.

Now let’s apply the same Savitzky-Golay filter to the wasp dataset.

window = int(0.2 * ds_wasp.fps)
ds_wasp_smooth.update({"position": savgol_filter(ds_wasp.position, window)})
Missing points (marked as NaN) in input
        Individual: individual_0
                head: 0/1085 (0.0%)
                stinger: 0/1085 (0.0%)

Missing points (marked as NaN) in output
        Individual: individual_0
                head: 0/1085 (0.0%)
                stinger: 0/1085 (0.0%)
<xarray.Dataset> Size: 61kB
Dimensions:      (time: 1085, space: 2, keypoints: 2, individuals: 1)
Coordinates:
  * time         (time) float64 9kB 0.0 0.025 0.05 0.075 ... 27.05 27.07 27.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U7 56B 'head' 'stinger'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float64 35kB 1.086e+03...
    confidence   (time, keypoints, individuals) float64 17kB 0.05305 ... 0.0
Attributes:
    fps:              40.0
    time_unit:        seconds
    source_software:  DeepLabCut
    source_file:      /home/runner/.movement/data/poses/DLC_single-wasp.predi...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-wasp_frame-10...
    video_path:       None


plot_raw_and_smooth_timeseries_and_psd(
    ds_wasp, ds_wasp_smooth, keypoint="stinger"
)
[Figure: raw (black) and smoothed (red) traces in two panels, “Time Domain” and “Frequency Domain”]

This example shows two important limitations of the Savitzky-Golay filter. First, the filter can introduce artefacts around sharp boundaries. For example, focus on what happens around the sudden drop in position during the final second. Second, the PSD appears to have large periodic drops at certain frequencies. Both of these effects vary with the choice of window and polyorder. You can read more about these and other limitations of the Savitzky-Golay filter in this paper.
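
The boundary artefact can be reproduced on a synthetic step: a position that jumps from 0 to 1 (arbitrary units) and stays there. The sketch re-implements the Savitzky-Golay filter as local least-squares fits and the median filter as a rolling median, both over centred, odd windows with edge frames left as NaN; these are simplifications for illustration, not movement's code.

```python
import numpy as np


def savgol_sketch(x, window, polyorder=2):
    """Local least-squares polynomial fit; edge frames are left as NaN."""
    half = window // 2
    offsets = np.arange(-half, half + 1)
    out = np.full(len(x), np.nan)
    for i in range(half, len(x) - half):
        coeffs = np.polyfit(offsets, x[i - half : i + half + 1], polyorder)
        out[i] = np.polyval(coeffs, 0)
    return out


def median_sketch(x, window):
    """Centred rolling median; edge frames are left as NaN."""
    half = window // 2
    out = np.full(len(x), np.nan)
    for i in range(half, len(x) - half):
        out[i] = np.median(x[i - half : i + half + 1])
    return out


# A sudden, sustained shift in position
step = np.concatenate([np.zeros(10), np.ones(10)])

sg = savgol_sketch(step, window=7)
med = median_sketch(step, window=7)
print("Savitzky-Golay min/max:", np.nanmin(sg), np.nanmax(sg))
print("median min/max:        ", np.nanmin(med), np.nanmax(med))
```

With a 7-frame window and polyorder=2, the Savitzky-Golay output dips to -2/21 (about -0.095) three frames before the step and overshoots to 23/21 (about 1.095) three frames after it, whereas the median reproduces the step exactly.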

Combining multiple smoothing filters

We can also combine multiple smoothing filters by applying them sequentially. For example, we can first apply the median filter with a small window to remove “spikes” and then apply the Savitzky-Golay filter with a larger window to further smooth the data. Between the two filters, we can interpolate over small gaps to avoid the excessive proliferation of NaN values. Let’s try this on the mouse dataset.

# First, we will apply the median filter.
window = int(0.1 * ds_mouse.fps)
ds_mouse_smooth.update(
    {"position": median_filter(ds_mouse.position, window, min_periods=2)}
)

# Next, let's linearly interpolate over gaps smaller
# than 1 second (30 frames).
ds_mouse_smooth.update(
    {"position": interpolate_over_time(ds_mouse_smooth.position, max_gap=30)}
)

# Finally, let's apply the Savitzky-Golay filter
# over a 0.4-second window (12 frames).
window = int(0.4 * ds_mouse.fps)
ds_mouse_smooth.update(
    {"position": savgol_filter(ds_mouse_smooth.position, window)}
)
Missing points (marked as NaN) in input
        Individual: individual_0
                snout: 4494/18485 (24.3%)
                left_ear: 513/18485 (2.8%)
                right_ear: 533/18485 (2.9%)
                centre: 490/18485 (2.7%)
                tail_base: 704/18485 (3.8%)
                tail_end: 2496/18485 (13.5%)

Missing points (marked as NaN) in output
        Individual: individual_0
                snout: 4455/18485 (24.1%)
                left_ear: 487/18485 (2.6%)
                right_ear: 507/18485 (2.7%)
                centre: 465/18485 (2.5%)
                tail_base: 673/18485 (3.6%)
                tail_end: 2428/18485 (13.1%)

Missing points (marked as NaN) in input
        Individual: individual_0
                snout: 4455/18485 (24.1%)
                left_ear: 487/18485 (2.6%)
                right_ear: 507/18485 (2.7%)
                centre: 465/18485 (2.5%)
                tail_base: 673/18485 (3.6%)
                tail_end: 2428/18485 (13.1%)

Missing points (marked as NaN) in output
        Individual: individual_0
                snout: 3011/18485 (16.3%)
                left_ear: 257/18485 (1.4%)
                right_ear: 294/18485 (1.6%)
                centre: 257/18485 (1.4%)
                tail_base: 257/18485 (1.4%)
                tail_end: 1101/18485 (6.0%)

Missing points (marked as NaN) in input
        Individual: individual_0
                snout: 3011/18485 (16.3%)
                left_ear: 257/18485 (1.4%)
                right_ear: 294/18485 (1.6%)
                centre: 257/18485 (1.4%)
                tail_base: 257/18485 (1.4%)
                tail_end: 1101/18485 (6.0%)

Missing points (marked as NaN) in output
        Individual: individual_0
                snout: 3520/18485 (19.0%)
                left_ear: 306/18485 (1.7%)
                right_ear: 354/18485 (1.9%)
                centre: 306/18485 (1.7%)
                tail_base: 306/18485 (1.7%)
                tail_end: 1304/18485 (7.1%)
<xarray.Dataset> Size: 1MB
Dimensions:      (time: 18485, space: 2, keypoints: 6, individuals: 1)
Coordinates:
  * time         (time) float64 148kB 0.0 0.03333 0.06667 ... 616.1 616.1 616.1
  * space        (space) <U1 8B 'x' 'y'
  * keypoints    (keypoints) <U9 216B 'snout' 'left_ear' ... 'tail_end'
  * individuals  (individuals) <U12 48B 'individual_0'
Data variables:
    position     (time, space, keypoints, individuals) float32 887kB nan ... ...
    confidence   (time, keypoints, individuals) float32 444kB nan nan ... 0.7607
Attributes:
    fps:              30.0
    time_unit:        seconds
    source_software:  SLEAP
    source_file:      /home/runner/.movement/data/poses/SLEAP_single-mouse_EP...
    ds_type:          poses
    frame_path:       /home/runner/.movement/data/frames/single-mouse_EPM_fra...
    video_path:       None


A record of all applied operations is stored in the log attribute of the ds_mouse_smooth.position data array. Let’s inspect it to summarise what we’ve done.

for entry in ds_mouse_smooth.position.log:
    print(entry)
{'operation': 'median_filter', 'datetime': '2024-12-09 16:46:38.559265', 'arg_1': 3, 'min_periods': 2}
{'operation': 'interpolate_over_time', 'datetime': '2024-12-09 16:46:38.594802', 'max_gap': 30}
{'operation': 'savgol_filter', 'datetime': '2024-12-09 16:46:38.609792', 'arg_1': 12}
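
The entries above record the function name, a timestamp, and the arguments passed. One way such a record could be produced is a decorator that appends a dictionary to a list kept in the result's attributes. The sketch below uses a plain Python object in place of an xarray DataArray, and the names `log_operation`, `Tracked` and `scale` are hypothetical, not movement's API; it simply stores positional arguments after the data under keys like arg_1, matching the pattern in the printed log.

```python
import functools
from datetime import datetime


class Tracked:
    """Hypothetical stand-in for a data array with an attrs dictionary."""

    def __init__(self, values, attrs=None):
        self.values = list(values)
        self.attrs = dict(attrs or {})


def log_operation(func):
    """Append a record of each call to the result's attrs["log"] list."""

    @functools.wraps(func)
    def wrapper(data, *args, **kwargs):
        result = func(data, *args, **kwargs)
        entry = {"operation": func.__name__, "datetime": str(datetime.now())}
        # Positional arguments after the data become arg_1, arg_2, ...
        entry.update({f"arg_{i}": arg for i, arg in enumerate(args, start=1)})
        entry.update(kwargs)
        # Build a new list so the input's log is left unchanged
        result.attrs["log"] = [*data.attrs.get("log", []), entry]
        return result

    return wrapper


@log_operation
def scale(data, factor, offset=0.0):
    """Toy operation: multiply and shift every value."""
    return Tracked([v * factor + offset for v in data.values], data.attrs)


raw = Tracked([1.0, 2.0, 3.0])
processed = scale(scale(raw, 2), 10, offset=1.0)
for entry in processed.attrs["log"]:
    print(entry)
```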

Now let’s visualise the difference between the raw data and the final smoothed result.

plot_raw_and_smooth_timeseries_and_psd(
    ds_mouse,
    ds_mouse_smooth,
    keypoint="snout",
    time_range=slice(0, 80),
)
[Figure: raw (black) and smoothed (red) traces in two panels, “Time Domain” and “Frequency Domain”]

Feel free to play around with the parameters of the applied filters and to also look at other keypoints and time ranges.
