Smooth pose tracks
Smooth pose tracks using the median and Savitzky-Golay filters.
Imports
from matplotlib import pyplot as plt
from scipy.signal import welch
from movement import sample_data
from movement.filtering import (
    interpolate_over_time,
    median_filter,
    savgol_filter,
)
Load a sample dataset
Let’s load a sample dataset and print it to inspect its contents.
Note that if you are running this notebook interactively, you can simply
type the variable name (here ds_wasp) in a cell to get an interactive
display of the dataset’s contents.
ds_wasp = sample_data.fetch_dataset("DLC_single-wasp.predictions.h5")
print(ds_wasp)
<xarray.Dataset> Size: 61kB
Dimensions: (time: 1085, space: 2, keypoints: 2, individuals: 1)
Coordinates:
* time (time) float64 9kB 0.0 0.025 0.05 0.075 ... 27.05 27.07 27.1
* space (space) <U1 8B 'x' 'y'
* keypoints (keypoints) <U7 56B 'head' 'stinger'
* individuals (individuals) <U12 48B 'individual_0'
Data variables:
position (time, space, keypoints, individuals) float64 35kB 1.086e+03...
confidence (time, keypoints, individuals) float64 17kB 0.05305 ... 0.0
Attributes:
fps: 40.0
time_unit: seconds
source_software: DeepLabCut
source_file: /home/runner/.movement/data/poses/DLC_single-wasp.predi...
ds_type: poses
frame_path: /home/runner/.movement/data/frames/single-wasp_frame-10...
video_path: None
We see that the dataset contains the 2D pose tracks and confidence scores for a single wasp, generated with DeepLabCut. The wasp is tracked at two keypoints, “head” and “stinger”, in a video that was recorded at 40 fps and lasts for approximately 27 seconds.
Define a plotting function
Let’s define a plotting function to help us visualise the effects of
smoothing in both the time and frequency domains.
The function takes as inputs two datasets containing raw and smoothed data,
respectively, and plots the position time series and power spectral density
(PSD) for a given individual and keypoint. It also allows you to
specify the spatial coordinate (x or y) and a time range to focus on.
The PSD is estimated with Welch’s method (scipy.signal.welch()), after
interpolating over missing values so that NaNs do not enter the calculation.
def plot_raw_and_smooth_timeseries_and_psd(
    ds_raw,
    ds_smooth,
    individual="individual_0",
    keypoint="stinger",
    space="x",
    time_range=None,
):
    # If no time range is specified, plot the entire time series
    if time_range is None:
        time_range = slice(0, ds_raw.time[-1])
    selection = {
        "time": time_range,
        "individuals": individual,
        "keypoints": keypoint,
        "space": space,
    }
    fig, ax = plt.subplots(2, 1, figsize=(10, 6))
    for ds, color, label in zip(
        [ds_raw, ds_smooth], ["k", "r"], ["raw", "smooth"], strict=False
    ):
        # plot position time series
        pos = ds.position.sel(**selection)
        ax[0].plot(
            pos.time,
            pos,
            color=color,
            lw=2,
            alpha=0.7,
            label=f"{label} {space}",
        )
        # interpolate data to remove NaNs in the PSD calculation
        pos_interp = interpolate_over_time(pos, print_report=False)
        # compute and plot the PSD
        freq, psd = welch(pos_interp, fs=ds.fps, nperseg=256)
        ax[1].semilogy(
            freq,
            psd,
            color=color,
            lw=2,
            alpha=0.7,
            label=f"{label} {space}",
        )
    ax[0].set_ylabel(f"{space} position (px)")
    ax[0].set_xlabel("Time (s)")
    ax[0].set_title("Time Domain")
    ax[0].legend()
    ax[1].set_ylabel("PSD (px$^2$/Hz)")
    ax[1].set_xlabel("Frequency (Hz)")
    ax[1].set_title("Frequency Domain")
    ax[1].legend()
    plt.tight_layout()
    fig.show()
Smoothing with a median filter
Using the movement.filtering.median_filter() function on the position
data variable, we can apply a rolling window median filter
over a 0.1-second window (4 frames) to the wasp dataset.
As the window parameter is defined as a number of observations,
we can simply multiply the desired time window by the frame rate
of the video. We will also create a copy of the dataset to avoid
modifying the original data.
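The conversion from a duration in seconds to a window length in frames can be wrapped in a small helper. The sketch below is hypothetical (seconds_to_frames is not part of movement) and uses round() rather than int(): int() truncates towards zero, so a floating-point product that lands just below a whole number would silently lose a frame.

```python
def seconds_to_frames(seconds, fps):
    """Convert a window duration in seconds to a whole number of frames.

    Hypothetical helper, not part of movement. round() is used instead of
    int() because int() truncates, so a product that lands just below a
    whole number would lose a frame.
    """
    return round(seconds * fps)


print(seconds_to_frames(0.1, 40))  # 4 frames at the wasp video's 40 fps
print(seconds_to_frames(0.1, 30))  # 3 frames at the mouse video's 30 fps
print(seconds_to_frames(2, 30))    # 60 frames

# Why truncation can bite: 0.29 * 100 evaluates to 28.999999999999996
print(int(0.29 * 100), round(0.29 * 100))  # 28 29
```

For the window sizes used in this example, int() and round() give the same result; the difference only matters for products that are not exactly representable as whole numbers.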
window = int(0.1 * ds_wasp.fps)
ds_wasp_smooth = ds_wasp.copy()
ds_wasp_smooth.update({"position": median_filter(ds_wasp.position, window)})
Missing points (marked as NaN) in input
Individual: individual_0
head: 0/1085 (0.0%)
stinger: 0/1085 (0.0%)
Missing points (marked as NaN) in output
Individual: individual_0
head: 0/1085 (0.0%)
stinger: 0/1085 (0.0%)
We see from the printed report that the dataset has no missing values, either before or after smoothing. Let’s visualise the effects of the median filter in the time and frequency domains.
plot_raw_and_smooth_timeseries_and_psd(
    ds_wasp, ds_wasp_smooth, keypoint="stinger"
)
We see that the median filter has removed the “spikes” present around the 14-second mark in the raw data. However, it has not dealt with the big shift occurring during the final second. In the frequency domain, we can see that the filter has reduced the power in the high-frequency components without affecting the low-frequency components.
This illustrates what the median filter is good at: removing brief “spikes” (e.g. a keypoint abruptly jumping to a different location for a frame or two) and high-frequency “jitter” (often present due to pose estimation working on a per-frame basis).
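To see why a median handles brief spikes better than an average, here is a minimal pure-Python sketch comparing a centred rolling median with a centred rolling mean on a made-up trace containing a one-frame spike. The helper centred_rolling is hypothetical, assumes an odd window, and only evaluates positions where the full window fits inside the series; it does not reproduce movement's handling of the ends or of NaNs.

```python
import statistics


def centred_rolling(values, window, statistic):
    """Apply `statistic` to every full, centred window (odd `window` assumed)."""
    half = window // 2
    return [
        statistic(values[i - half : i + half + 1])
        for i in range(half, len(values) - half)
    ]


# Made-up x coordinates: a steady keypoint that jumps away for one frame
trace = [5.0, 5.0, 5.0, 50.0, 5.0, 5.0, 5.0]

print(centred_rolling(trace, 3, statistics.median))  # [5.0, 5.0, 5.0, 5.0, 5.0]
print(centred_rolling(trace, 3, statistics.fmean))   # [5.0, 20.0, 20.0, 20.0, 5.0]
```

The median ignores the single outlier entirely, whereas the mean smears it across every window that contains it.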
Choosing parameters for the median filter
We can control the behaviour of the median filter
via two parameters: window and min_periods.
To better understand the effect of these parameters, let’s use a
dataset that contains missing values.
ds_mouse = sample_data.fetch_dataset("SLEAP_single-mouse_EPM.analysis.h5")
print(ds_mouse)
<xarray.Dataset> Size: 1MB
Dimensions: (time: 18485, space: 2, keypoints: 6, individuals: 1)
Coordinates:
* time (time) float64 148kB 0.0 0.03333 0.06667 ... 616.1 616.1 616.1
* space (space) <U1 8B 'x' 'y'
* keypoints (keypoints) <U9 216B 'snout' 'left_ear' ... 'tail_end'
* individuals (individuals) <U12 48B 'individual_0'
Data variables:
position (time, space, keypoints, individuals) float32 887kB nan ... ...
confidence (time, keypoints, individuals) float32 444kB nan nan ... 0.7607
Attributes:
fps: 30.0
time_unit: seconds
source_software: SLEAP
source_file: /home/runner/.movement/data/poses/SLEAP_single-mouse_EP...
ds_type: poses
frame_path: /home/runner/.movement/data/frames/single-mouse_EPM_fra...
video_path: None
The dataset contains a single mouse with six keypoints tracked in 2D space. The video was recorded at 30 fps and lasts for ~616 seconds. We can see that there are some missing values, indicated as “nan” in the printed dataset. Let’s apply the median filter over a 0.1-second window (3 frames) to the dataset.
window = int(0.1 * ds_mouse.fps)
ds_mouse_smooth = ds_mouse.copy()
ds_mouse_smooth.update({"position": median_filter(ds_mouse.position, window)})
Missing points (marked as NaN) in input
Individual: individual_0
snout: 4494/18485 (24.3%)
left_ear: 513/18485 (2.8%)
right_ear: 533/18485 (2.9%)
centre: 490/18485 (2.7%)
tail_base: 704/18485 (3.8%)
tail_end: 2496/18485 (13.5%)
Missing points (marked as NaN) in output
Individual: individual_0
snout: 5106/18485 (27.6%)
left_ear: 678/18485 (3.7%)
right_ear: 695/18485 (3.8%)
centre: 640/18485 (3.5%)
tail_base: 913/18485 (4.9%)
tail_end: 3103/18485 (16.8%)
The report informs us that the raw data contains NaN values, most of which
occur at the snout and tail_end keypoints. After filtering, the
number of NaNs has increased. This is because the default behaviour of the
median filter is to propagate NaN values, i.e. if any value in the rolling
window is NaN, the output will also be NaN.
To modify this behaviour, you can set the min_periods parameter to an
integer value. This parameter determines the minimum number of non-NaN
values required in the window for the output to be non-NaN.
For example, setting min_periods=2 means that two non-NaN values in the
window are sufficient for the median to be calculated. Let’s try this.
ds_mouse_smooth.update(
    {"position": median_filter(ds_mouse.position, window, min_periods=2)}
)
Missing points (marked as NaN) in input
Individual: individual_0
snout: 4494/18485 (24.3%)
left_ear: 513/18485 (2.8%)
right_ear: 533/18485 (2.9%)
centre: 490/18485 (2.7%)
tail_base: 704/18485 (3.8%)
tail_end: 2496/18485 (13.5%)
Missing points (marked as NaN) in output
Individual: individual_0
snout: 4455/18485 (24.1%)
left_ear: 487/18485 (2.6%)
right_ear: 507/18485 (2.7%)
centre: 465/18485 (2.5%)
tail_base: 673/18485 (3.6%)
tail_end: 2428/18485 (13.1%)
We see that this time the number of NaN values has decreased
across all keypoints.
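The effect of min_periods can be reproduced on a toy series. The sketch below is a hypothetical pure-Python rolling median, not movement's implementation: it centres each window, pads the ends by reflection (an assumption chosen so that a gap-free input stays gap-free, as in the wasp report above), and returns NaN whenever the window holds fewer than min_periods valid values, defaulting to the full window size. It assumes min_periods is at least 1.

```python
import math
import statistics


def rolling_median(values, window, min_periods=None):
    """Centred rolling median that skips NaNs (hypothetical helper)."""
    n = len(values)
    if min_periods is None:
        min_periods = window  # any NaN in the window makes the output NaN
    left = window // 2
    result = []
    for i in range(n):
        in_window = []
        for j in range(i - left, i - left + window):
            # Reflect indices that fall beyond either end of the series
            if j < 0:
                k = -j
            elif j >= n:
                k = 2 * (n - 1) - j
            else:
                k = j
            in_window.append(values[k])
        valid = [v for v in in_window if not math.isnan(v)]
        if len(valid) >= min_periods:
            result.append(statistics.median(valid))
        else:
            result.append(math.nan)
    return result


trace = [1.0, 2.0, math.nan, 4.0, 5.0]
print(rolling_median(trace, 3))                 # [2.0, nan, nan, nan, 4.0]
print(rolling_median(trace, 3, min_periods=2))  # [2.0, 1.5, 3.0, 4.5, 4.0]
```

With the default, the single missing sample spreads to its two neighbours (one NaN in, three out), mirroring the growth seen in the first mouse report; with min_periods=2 every output is defined.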
Let’s visualise the effects of the median filter in the time and frequency
domains. Here we focus on the first 80 seconds for the snout keypoint.
You can adjust the keypoint and time_range arguments to explore other
parts of the data.
plot_raw_and_smooth_timeseries_and_psd(
    ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80)
)
The smoothing once again reduces the power of high-frequency components, but the resulting time series stays quite close to the raw data.
What happens if we increase the window to 2 seconds (60 frames)?
window = int(2 * ds_mouse.fps)
ds_mouse_smooth.update(
    {"position": median_filter(ds_mouse.position, window, min_periods=2)}
)
Missing points (marked as NaN) in input
Individual: individual_0
snout: 4494/18485 (24.3%)
left_ear: 513/18485 (2.8%)
right_ear: 533/18485 (2.9%)
centre: 490/18485 (2.7%)
tail_base: 704/18485 (3.8%)
tail_end: 2496/18485 (13.5%)
Missing points (marked as NaN) in output
Individual: individual_0
snout: 795/18485 (4.3%)
left_ear: 80/18485 (0.4%)
right_ear: 80/18485 (0.4%)
centre: 80/18485 (0.4%)
tail_base: 80/18485 (0.4%)
tail_end: 239/18485 (1.3%)
The number of NaN values has decreased even further. That’s because the chance of finding at least 2 valid values within a 2-second window (i.e. 60 frames) is quite high. Let’s plot the results for the same keypoint and time range as before.
plot_raw_and_smooth_timeseries_and_psd(
    ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80)
)
We see that the filtered time series is much smoother and it has even
“bridged” over some small gaps. That said, it often deviates from the raw
data in ways that may not be desirable, depending on the application.
Here, our choice of window may be too large.
In general, you should choose a window that is small enough to
preserve the original data structure, but large enough to remove
“spikes” and high-frequency noise. Always inspect the results to ensure
that the filter is not removing important features.
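Why a longer window bridges gaps can be checked by counting valid samples per window, without computing any medians. The hypothetical helper below assumes centred windows with reflected ends (an assumption, not movement's documented behaviour) and counts the positions whose window holds fewer than min_periods valid samples, i.e. where a rolling filter's output would be NaN. The 200-frame track with a 10-frame gap is made up.

```python
import math


def nan_outputs(values, window, min_periods=None):
    """Count outputs of a centred rolling filter that would be NaN (hypothetical)."""
    n = len(values)
    if min_periods is None:
        min_periods = window
    left = window // 2
    count = 0
    for i in range(n):
        valid = 0
        for j in range(i - left, i - left + window):
            # Reflect indices that fall beyond either end of the series
            k = -j if j < 0 else (2 * (n - 1) - j if j >= n else j)
            if not math.isnan(values[k]):
                valid += 1
        if valid < min_periods:
            count += 1
    return count


# Made-up track: 200 frames with a 10-frame gap in the middle
track = [float(i) for i in range(200)]
track[100:110] = [math.nan] * 10

print(nan_outputs(track, 3))                  # 12: the gap plus one frame on each side
print(nan_outputs(track, 3, min_periods=2))   # 10: the gap itself stays missing
print(nan_outputs(track, 60, min_periods=2))  # 0: every 60-frame window reaches valid data
```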
Smoothing with a Savitzky-Golay filter
Here we apply the movement.filtering.savgol_filter() function
(a wrapper around scipy.signal.savgol_filter()) to the position
data variable.
The Savitzky-Golay filter is a polynomial smoothing filter that can be
applied to time series data on a rolling window basis.
A polynomial of degree polyorder is fitted by least squares to each
data segment of size window, and the value of the fitted polynomial at
the midpoint of each window is used as the output value.
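The fitting step described above can be sketched in pure Python. This is an illustrative re-implementation, not movement's or SciPy's code: savgol_weights solves the least-squares normal equations for a polynomial of degree polyorder over an odd-sized window and returns the weights that give the fitted value at the window centre, and savgol_smooth applies them as a weighted sum. Unlike scipy.signal.savgol_filter(), which by default also fits the ends of the series, this sketch simply leaves the first and last window // 2 samples as NaN.

```python
import math


def _solve(matrix, rhs):
    """Solve a small dense linear system by Gauss-Jordan elimination."""
    n = len(rhs)
    aug = [list(row) + [rhs[i]] for i, row in enumerate(matrix)]
    for col in range(n):
        # Partial pivoting: move the largest remaining entry onto the diagonal
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(n):
            if r != col:
                factor = aug[r][col] / aug[col][col]
                for c in range(col, n + 1):
                    aug[r][c] -= factor * aug[col][c]
    return [aug[i][n] / aug[i][i] for i in range(n)]


def savgol_weights(window, polyorder):
    """Weights giving the least-squares polynomial's value at the window centre."""
    if window % 2 == 0 or polyorder >= window:
        raise ValueError("this sketch needs an odd window larger than polyorder")
    half = window // 2
    # Design matrix A: one row per sample, holding powers of its offset from the centre
    design = [[float(k) ** p for p in range(polyorder + 1)] for k in range(-half, half + 1)]
    n = polyorder + 1
    normal = [[sum(row[i] * row[j] for row in design) for j in range(n)] for i in range(n)]
    # The fitted value at offset 0 is the constant coefficient, i.e. row 0 of
    # (A^T A)^-1 A^T. As A^T A is symmetric, that row equals A z with (A^T A) z = e_0.
    z = _solve(normal, [1.0] + [0.0] * (n - 1))
    return [sum(a * zj for a, zj in zip(row, z)) for row in design]


def savgol_smooth(values, window, polyorder=2):
    """Weighted sum over each full window; the ends are left as NaN in this sketch."""
    weights = savgol_weights(window, polyorder)
    half = window // 2
    out = [math.nan] * len(values)
    for i in range(half, len(values) - half):
        out[i] = sum(w * values[i - half + k] for k, w in enumerate(weights))
    return out


print([round(w * 35, 6) for w in savgol_weights(5, 2)])
# [-3.0, 12.0, 17.0, 12.0, -3.0]
```

For a 5-sample window and polyorder=2 the weights come out as (-3, 12, 17, 12, -3)/35, the classic Savitzky-Golay smoothing coefficients, and any quadratic signal passes through unchanged.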
Let’s try it on the mouse dataset, this time using a 0.2-second
window (i.e. 6 frames) and the default polyorder=2 for smoothing.
As before, we first compute the corresponding number of observations
to be used as the window size.
window = int(0.2 * ds_mouse.fps)
ds_mouse_smooth.update({"position": savgol_filter(ds_mouse.position, window)})
Missing points (marked as NaN) in input
Individual: individual_0
snout: 4494/18485 (24.3%)
left_ear: 513/18485 (2.8%)
right_ear: 533/18485 (2.9%)
centre: 490/18485 (2.7%)
tail_base: 704/18485 (3.8%)
tail_end: 2496/18485 (13.5%)
Missing points (marked as NaN) in output
Individual: individual_0
snout: 5810/18485 (31.4%)
left_ear: 895/18485 (4.8%)
right_ear: 905/18485 (4.9%)
centre: 839/18485 (4.5%)
tail_base: 1186/18485 (6.4%)
tail_end: 3801/18485 (20.6%)
We see that the number of NaN values has increased after filtering. This is
for the same reason as with the median filter (in its default mode), i.e.
if there is at least one NaN value in the window, the output will be NaN.
Unlike the median filter, the Savitzky-Golay filter does not provide a
min_periods parameter to control this behaviour. Let’s visualise the
effects in the time and frequency domains.
plot_raw_and_smooth_timeseries_and_psd(
    ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80)
)
Once again, the power of high-frequency components has been reduced, but more missing values have been introduced.
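The NaN propagation follows directly from the Savitzky-Golay filter being a weighted sum: in IEEE floating-point arithmetic any product or sum involving NaN is NaN, so one missing sample spoils every window that contains it. The sketch below applies the textbook 5-sample, polyorder=2 weights (-3, 12, 17, 12, -3)/35 to a made-up linear trace with one NaN; the helper weighted_windows is hypothetical and only evaluates full windows.

```python
import math

# Classic 5-sample, quadratic Savitzky-Golay smoothing weights
WEIGHTS = [c / 35 for c in (-3, 12, 17, 12, -3)]


def weighted_windows(values, weights):
    """Weighted sum over every full, centred window (hypothetical helper)."""
    half = len(weights) // 2
    return [
        sum(w * values[i - half + k] for k, w in enumerate(weights))
        for i in range(half, len(values) - half)
    ]


trace = [float(t) for t in range(13)]
trace[6] = math.nan  # a single missing frame

smoothed = weighted_windows(trace, WEIGHTS)
print(sum(math.isnan(v) for v in smoothed))  # 5: every window covering frame 6 is lost
```

Interpolating over short gaps before applying the filter, as done later in this example, avoids this loss.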
Now let’s apply the same Savitzky-Golay filter to the wasp dataset.
window = int(0.2 * ds_wasp.fps)
ds_wasp_smooth.update({"position": savgol_filter(ds_wasp.position, window)})
Missing points (marked as NaN) in input
Individual: individual_0
head: 0/1085 (0.0%)
stinger: 0/1085 (0.0%)
Missing points (marked as NaN) in output
Individual: individual_0
head: 0/1085 (0.0%)
stinger: 0/1085 (0.0%)
plot_raw_and_smooth_timeseries_and_psd(
    ds_wasp, ds_wasp_smooth, keypoint="stinger"
)
This example shows two important limitations of the Savitzky-Golay filter.
First, the filter can introduce artefacts around sharp boundaries. For
example, focus on what happens around the sudden drop in position
during the final second. Second, the PSD appears to have large periodic
drops at certain frequencies. Both of these effects vary with the
choice of window and polyorder. You can read more about these
and other limitations of the Savitzky-Golay filter in
this paper.
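The first limitation can be reproduced on a made-up step change, a crude stand-in for the sudden shift at the end of the wasp track. The sketch below uses the textbook 5-sample, polyorder=2 Savitzky-Golay weights and a plain median over full windows only (hypothetical helpers, not movement's functions): the Savitzky-Golay output undershoots before the step and overshoots after it, whereas the median reproduces the step exactly.

```python
import statistics

# Classic 5-sample, quadratic Savitzky-Golay smoothing weights
WEIGHTS = [c / 35 for c in (-3, 12, 17, 12, -3)]
HALF = len(WEIGHTS) // 2


def windows(values):
    """Yield every full, centred 5-sample window."""
    for i in range(HALF, len(values) - HALF):
        yield values[i - HALF : i + HALF + 1]


step = [0.0] * 5 + [10.0] * 5  # a sudden jump between frames 4 and 5

savgol_out = [sum(w * v for w, v in zip(WEIGHTS, win)) for win in windows(step)]
median_out = [statistics.median(win) for win in windows(step)]

print([round(v, 2) for v in savgol_out])  # [0.0, -0.86, 2.57, 7.43, 10.86, 10.0]
print(median_out)                         # [0.0, 0.0, 0.0, 10.0, 10.0, 10.0]
```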
Combining multiple smoothing filters
We can also combine multiple smoothing filters by applying them
sequentially. For example, we can first apply the median filter with a small
window to remove “spikes” and then apply the Savitzky-Golay filter
with a larger window to further smooth the data.
Between the two filters, we can interpolate over small gaps to avoid the
excessive proliferation of NaN values. Let’s try this on the mouse dataset.
# First, we will apply the median filter.
window = int(0.1 * ds_mouse.fps)
ds_mouse_smooth.update(
    {"position": median_filter(ds_mouse.position, window, min_periods=2)}
)
# Next, let's linearly interpolate over gaps smaller
# than 1 second (30 frames).
ds_mouse_smooth.update(
    {"position": interpolate_over_time(ds_mouse_smooth.position, max_gap=30)}
)
# Finally, let's apply the Savitzky-Golay filter
# over a 0.4-second window (12 frames).
window = int(0.4 * ds_mouse.fps)
ds_mouse_smooth.update(
    {"position": savgol_filter(ds_mouse_smooth.position, window)}
)
Missing points (marked as NaN) in input
Individual: individual_0
snout: 4494/18485 (24.3%)
left_ear: 513/18485 (2.8%)
right_ear: 533/18485 (2.9%)
centre: 490/18485 (2.7%)
tail_base: 704/18485 (3.8%)
tail_end: 2496/18485 (13.5%)
Missing points (marked as NaN) in output
Individual: individual_0
snout: 4455/18485 (24.1%)
left_ear: 487/18485 (2.6%)
right_ear: 507/18485 (2.7%)
centre: 465/18485 (2.5%)
tail_base: 673/18485 (3.6%)
tail_end: 2428/18485 (13.1%)
Missing points (marked as NaN) in input
Individual: individual_0
snout: 4455/18485 (24.1%)
left_ear: 487/18485 (2.6%)
right_ear: 507/18485 (2.7%)
centre: 465/18485 (2.5%)
tail_base: 673/18485 (3.6%)
tail_end: 2428/18485 (13.1%)
Missing points (marked as NaN) in output
Individual: individual_0
snout: 3011/18485 (16.3%)
left_ear: 257/18485 (1.4%)
right_ear: 294/18485 (1.6%)
centre: 257/18485 (1.4%)
tail_base: 257/18485 (1.4%)
tail_end: 1101/18485 (6.0%)
Missing points (marked as NaN) in input
Individual: individual_0
snout: 3011/18485 (16.3%)
left_ear: 257/18485 (1.4%)
right_ear: 294/18485 (1.6%)
centre: 257/18485 (1.4%)
tail_base: 257/18485 (1.4%)
tail_end: 1101/18485 (6.0%)
Missing points (marked as NaN) in output
Individual: individual_0
snout: 3520/18485 (19.0%)
left_ear: 306/18485 (1.7%)
right_ear: 354/18485 (1.9%)
centre: 306/18485 (1.7%)
tail_base: 306/18485 (1.7%)
tail_end: 1304/18485 (7.1%)
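The interpolation step between the two filters fills short runs of missing values with straight lines. As a rough sketch of the idea, the hypothetical helper below fills interior runs of NaN that are at most max_gap samples long and leaves longer runs, as well as gaps at the very start or end, untouched. Counting a gap as the number of consecutive NaN samples is an assumption here; check the movement.filtering.interpolate_over_time() documentation for how max_gap is defined in your version.

```python
import math


def interpolate_gaps(values, max_gap=None):
    """Linearly fill interior runs of NaN of at most `max_gap` samples (hypothetical)."""
    out = list(values)
    n = len(out)
    i = 0
    while i < n:
        if not math.isnan(out[i]):
            i += 1
            continue
        start = i  # first NaN of this run
        while i < n and math.isnan(out[i]):
            i += 1
        end = i  # first index after the run
        gap = end - start
        # Leading/trailing runs have only one valid neighbour, so leave them;
        # runs longer than max_gap are also left untouched
        if start == 0 or end == n or (max_gap is not None and gap > max_gap):
            continue
        left, right = out[start - 1], out[end]
        for k in range(start, end):
            fraction = (k - start + 1) / (gap + 1)
            out[k] = left + fraction * (right - left)
    return out


track = [0.0, math.nan, math.nan, 3.0, math.nan, math.nan, math.nan, math.nan, 8.0]
print(interpolate_gaps(track, max_gap=2))
# [0.0, 1.0, 2.0, 3.0, nan, nan, nan, nan, 8.0]
```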
A record of all applied operations is stored in the log attribute of the
ds_mouse_smooth.position data array. Let’s inspect it to summarise
what we’ve done.
for entry in ds_mouse_smooth.position.log:
    print(entry)
{'operation': 'median_filter', 'datetime': '2024-12-09 16:46:38.559265', 'arg_1': 3, 'min_periods': 2}
{'operation': 'interpolate_over_time', 'datetime': '2024-12-09 16:46:38.594802', 'max_gap': 30}
{'operation': 'savgol_filter', 'datetime': '2024-12-09 16:46:38.609792', 'arg_1': 12}
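The log is bookkeeping that travels with the data. One way such a record could be produced is a decorator that appends an entry for each call. The sketch below is hypothetical and is not how movement implements its log, although it mimics the entry layout printed above: positional arguments after the data become arg_1, arg_2 and so on, while keyword arguments keep their names. Series and scale are made-up names for illustration.

```python
import functools
from datetime import datetime


class Series(list):
    """A plain list that can carry a `log` attribute (illustration only)."""


def log_operation(func):
    """Record the name, time and arguments of each call on the returned data."""

    @functools.wraps(func)
    def wrapper(data, *args, **kwargs):
        result = func(data, *args, **kwargs)
        entry = {"operation": func.__name__, "datetime": str(datetime.now())}
        # Positional arguments after the data are numbered arg_1, arg_2, ...
        entry.update({f"arg_{i}": arg for i, arg in enumerate(args, start=1)})
        entry.update(kwargs)
        # Carry over the history of the input and append this call
        result.log = list(getattr(data, "log", [])) + [entry]
        return result

    return wrapper


@log_operation
def scale(data, factor, offset=0.0):
    return Series(x * factor + offset for x in data)


result = scale(scale(Series([1.0, 2.0]), 3, offset=1.0), 2)
for entry in result.log:
    print(entry)
```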
Now let’s visualise the difference between the raw data and the final smoothed result.
plot_raw_and_smooth_timeseries_and_psd(
    ds_mouse,
    ds_mouse_smooth,
    keypoint="snout",
    time_range=slice(0, 80),
)
Feel free to play around with the parameters of the applied filters and to also look at other keypoints and time ranges.
See also
Filtering multiple data variables in the Drop outliers and interpolate example.