# Source code for movement.plots.trajectory

"""Wrappers to plot movement data."""

import xarray as xr
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure

# Baseline keyword arguments for ``Axes.scatter``; user-supplied kwargs
# take precedence over these when the two are merged.
DEFAULT_PLOTTING_ARGS = dict(
    s=15,  # marker size
    marker="o",
    alpha=1.0,
)


def plot_centroid_trajectory(
    da: xr.DataArray,
    individual: str | None = None,
    keypoints: str | list[str] | None = None,
    ax: Axes | None = None,
    **kwargs,
) -> tuple[Figure | SubFigure, Axes]:
    """Plot centroid trajectory.

    This function plots the trajectory of the centroid of multiple keypoints
    for a given individual. By default, the trajectory is colored by time
    (using the default colormap). Pass a different colormap through ``cmap``
    if desired. If a single keypoint is passed, the trajectory will be the
    same as the trajectory of the keypoint.

    Parameters
    ----------
    da : xr.DataArray
        A data array containing position information, with `time` and
        `space` as required dimensions. Optionally, it may have
        `individuals` and/or `keypoints` dimensions.
    individual : str, optional
        The name of the individual to be plotted. By default, the first
        individual is plotted.
    keypoints : str, list[str], optional
        The name of the keypoint to be plotted, or a list of keypoint names
        (their centroid will be plotted). By default, the centroid of all
        keypoints is plotted.
    ax : matplotlib.axes.Axes or None, optional
        Axes object on which to draw the trajectory. If None, a new
        figure and axes are created.
    **kwargs : dict
        Additional keyword arguments passed to
        :meth:`matplotlib.axes.Axes.scatter`. They override the module's
        default plotting arguments. If neither ``c`` nor a non-None
        ``color`` is given, points are coloured by the ``time`` coordinate
        and a colorbar labelled "Time" is added.

    Returns
    -------
    fig : matplotlib.figure.Figure or matplotlib.figure.SubFigure
        If ``ax`` is provided, this is ``ax.figure``
        (:class:`matplotlib.figure.Figure` or
        :class:`matplotlib.figure.SubFigure`). Otherwise, a new
        :class:`matplotlib.figure.Figure` is created and returned.
    ax : matplotlib.axes.Axes
        Axes on which the trajectory was drawn. If ``ax`` is provided,
        the input will be directly modified and returned in this value.

    Raises
    ------
    ValueError
        If ``individual`` is a list (only one individual can be plotted).

    """
    if isinstance(individual, list):
        raise ValueError("Only one individual can be selected.")

    # Build the label-based selection only for the optional dims present.
    selection = {}
    if "individuals" in da.dims:
        selection["individuals"] = individual or da.individuals.values[0]
    if "keypoints" in da.dims:
        selection["keypoints"] = keypoints or da.keypoints.values

    plot_point = da.sel(selection)

    # If there are multiple selected keypoints, calculate the centroid
    plot_point = (
        plot_point.mean(dim="keypoints", skipna=True)
        if "keypoints" in plot_point.dims and plot_point.sizes["keypoints"] > 1
        else plot_point
    )

    # Drop leftover size-1 dims (e.g. a one-element keypoints list);
    # only space and time should remain
    plot_point = plot_point.squeeze()

    fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (ax.figure, ax)

    # A colour may be requested through either ``c`` or ``color``.
    # matplotlib's scatter raises a ValueError if ``c`` and a non-None
    # ``color`` are both passed, so only fall back to colouring by time
    # when the user supplied neither.
    colour_provided = "c" in kwargs or kwargs.get("color") is not None

    # Merge default plotting args with user-provided kwargs
    kwargs = {**DEFAULT_PLOTTING_ARGS, **kwargs}
    if not colour_provided:
        kwargs["c"] = plot_point.time

    # Plot the scatter, colouring by time or user-provided colour
    sc = ax.scatter(
        plot_point.sel(space="x"),
        plot_point.sel(space="y"),
        **kwargs,
    )

    # Add 'colorbar' for time dimension if no colour was provided by user
    if not colour_provided:
        cbar = fig.colorbar(sc, ax=ax, label="Time")
        # Keep the colorbar fully opaque even if the points use alpha < 1
        if cbar.solids is not None:
            cbar.solids.set(alpha=1.0)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("Trajectory")

    return fig, ax