Source code for movement.plots.trajectory
"""Wrappers to plot movement data."""
import xarray as xr
from matplotlib import pyplot as plt
# Fallback keyword arguments for the scatter call made by
# ``plot_centroid_trajectory``. Each entry is applied only when the caller
# has not supplied that keyword: marker size (``s``), marker shape
# (``marker``) and marker opacity (``alpha``).
DEFAULT_PLOTTING_ARGS = {"s": 15, "marker": "o", "alpha": 1.0}
[docs]
def plot_centroid_trajectory(
    da: xr.DataArray,
    individual: str | None = None,
    keypoints: str | list[str] | None = None,
    ax: plt.Axes | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot centroid trajectory.

    This function plots the trajectory of the centroid
    of multiple keypoints for a given individual. By default, the trajectory
    is colored by time (using the default colormap) and a colorbar labelled
    "Time" is added. Pass a different colormap through ``cmap`` if desired.
    If a single keypoint is passed, the trajectory will be the same as the
    trajectory of the keypoint.

    Parameters
    ----------
    da : xr.DataArray
        A data array containing position information, with ``time`` and
        ``space`` as required dimensions. Optionally, it may have
        ``individuals`` and/or ``keypoints`` dimensions.
    individual : str, optional
        The name of the individual to be plotted. By default, the first
        individual is plotted. Ignored if ``da`` has no ``individuals``
        dimension.
    keypoints : str, list[str], optional
        The name of the keypoint to be plotted, or a list of keypoint names
        (their centroid will be plotted). By default, the centroid of all
        keypoints is plotted. Ignored if ``da`` has no ``keypoints``
        dimension.
    ax : matplotlib.axes.Axes or None, optional
        Axes object on which to draw the trajectory. If None, a new
        figure and axes are created.
    **kwargs : dict
        Additional keyword arguments passed to
        ``matplotlib.axes.Axes.scatter()``. Values given here take
        precedence over ``DEFAULT_PLOTTING_ARGS``. If either ``c`` or
        ``color`` is supplied, the points are not colored by time and no
        colorbar is added.

    Returns
    -------
    (figure, axes) : tuple of (matplotlib.pyplot.Figure, matplotlib.axes.Axes)
        The figure and axes containing the trajectory plot.

    Raises
    ------
    ValueError
        If ``individual`` is a list (only one individual can be plotted).

    """
    if isinstance(individual, list):
        raise ValueError("Only one individual can be selected.")

    # Build the selection only for the dimensions that are actually present
    selection = {}
    if "individuals" in da.dims:
        selection["individuals"] = (
            da.individuals.values[0] if individual is None else individual
        )
    if "keypoints" in da.dims:
        selection["keypoints"] = (
            da.keypoints.values if keypoints is None else keypoints
        )
    plot_point = da.sel(**selection)

    # If there are multiple selected keypoints, calculate the centroid
    has_many_keypoints = (
        "keypoints" in plot_point.dims and plot_point.sizes["keypoints"] > 1
    )
    if has_many_keypoints:
        plot_point = plot_point.mean(dim="keypoints", skipna=True)

    # Drop any remaining length-1 dimensions (e.g. a single keypoint
    # selected via a list) so that only space and time should remain
    plot_point = plot_point.squeeze()

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    else:
        fig = ax.figure

    # Merge default plotting args with user-provided kwargs
    for key, value in DEFAULT_PLOTTING_ARGS.items():
        kwargs.setdefault(key, value)

    # Colour by time only when the user has not chosen a colour themselves.
    # ``color`` must be checked as well as ``c``: ``Axes.scatter`` raises a
    # ValueError when both are supplied.
    colour_by_time = "c" not in kwargs and "color" not in kwargs
    if colour_by_time:
        kwargs["c"] = plot_point.time

    # Plot the scatter, colouring by time or user-provided colour
    sc = ax.scatter(
        plot_point.sel(space="x"),
        plot_point.sel(space="y"),
        **kwargs,
    )

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("Trajectory")

    # Add 'colorbar' for time dimension if no colour was provided by user
    if colour_by_time:
        colorbar = fig.colorbar(sc, ax=ax, label="Time")
        # Keep the colorbar opaque even if the markers are translucent
        colorbar.solids.set(alpha=1.0)

    return fig, ax