Source code for movement.kinematics.distances

"""Computing spatial relationships between points, such as distances."""

import itertools
from collections.abc import Mapping
from typing import Literal

import xarray as xr
from scipy.spatial.distance import cdist

from movement.utils.logging import logger
from movement.validators.arrays import validate_dims_coords


def _cdist(
    a: xr.DataArray,
    b: xr.DataArray,
    dim: Literal["individuals", "keypoints"],
    metric: str | None = "euclidean",
    **kwargs,
) -> xr.DataArray:
    """Compute distances between two position arrays across a given dimension.

    This function is a wrapper around :func:`scipy.spatial.distance.cdist`
    and computes the pairwise distances between the two input position arrays
    across the dimension specified by ``dim``.
    The dimension can be either ``individuals`` or ``keypoints``.
    The distances are computed using the specified ``metric``.

    Parameters
    ----------
    a : xarray.DataArray
        The first input data containing position information of a
        single individual or keypoint, with ``time``, ``space``
        (in Cartesian coordinates), and ``individuals`` or ``keypoints``
        (as specified by ``dim``) as required dimensions.
    b : xarray.DataArray
        The second input data containing position information of a
        single individual or keypoint, with ``time``, ``space``
        (in Cartesian coordinates), and ``individuals`` or ``keypoints``
        (as specified by ``dim``) as required dimensions.
    dim : str
        The dimension to compute the distances for. Must be either
        ``'individuals'`` or ``'keypoints'``.
    metric : str, optional
        The distance metric to use. Must be one of the options supported
        by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
        ``'euclidean'``, etc.
        Defaults to ``'euclidean'``.
    **kwargs : dict
        Additional keyword arguments to pass to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    xarray.DataArray
        An xarray DataArray named ``"distance"`` containing the computed
        distances between each pair of inputs. The two output dimensions
        are named after the selected elements of ``a`` and ``b`` along
        ``dim``. They are labelled by the labels of ``a`` and ``b``
        respectively, taken along the other of ``individuals``/``keypoints``.
        Length-1 dimensions are squeezed out.

    Examples
    --------
    Compute the Euclidean distance (default) between ``ind1`` and
    ``ind2`` (i.e. interindividual distance for all keypoints)
    using the ``position`` data variable in the Dataset ``ds``:

    >>> pos1 = ds.position.sel(individuals="ind1")
    >>> pos2 = ds.position.sel(individuals="ind2")
    >>> ind_dists = _cdist(pos1, pos2, dim="individuals")

    Compute the Euclidean distance (default) between ``key1`` and
    ``key2`` (i.e. interkeypoint distance for all individuals)
    using the ``position`` data variable in the Dataset ``ds``:

    >>> pos1 = ds.position.sel(keypoints="key1")
    >>> pos2 = ds.position.sel(keypoints="key2")
    >>> key_dists = _cdist(pos1, pos2, dim="keypoints")

    See Also
    --------
    scipy.spatial.distance.cdist : The underlying function used.
    compute_pairwise_distances : Compute pairwise distances between
        ``individuals`` or ``keypoints``

    """
    # The dimension from which ``dim`` labels are obtained
    labels_dim = "individuals" if dim == "keypoints" else "keypoints"
    # The (scalar) names of the two selected elements; they become the
    # names of the two output dimensions.
    elem1 = getattr(a, dim).item()
    elem2 = getattr(b, dim).item()
    a = _validate_labels_dimension(a, labels_dim)
    b = _validate_labels_dimension(b, labels_dim)
    result = xr.apply_ufunc(
        cdist,
        a,
        b,
        kwargs={"metric": metric, **kwargs},
        input_core_dims=[[labels_dim, "space"], [labels_dim, "space"]],
        output_core_dims=[[elem1, elem2]],
        vectorize=True,
    )
    # cdist yields a (labels of a) x (labels of b) matrix, so each output
    # axis is labelled by the labels of its own input.
    result = result.assign_coords(
        {
            elem1: getattr(a, labels_dim).values,
            elem2: getattr(b, labels_dim).values,
        }
    )
    result.name = "distance"
    # Drop any squeezed coordinates (e.g. the placeholder "temp_dim" labels).
    # Note: squeeze() without ``dim`` removes *every* length-1 dimension.
    return result.squeeze(drop=True)


def compute_pairwise_distances(
    data: xr.DataArray,
    dim: Literal["individuals", "keypoints"],
    pairs: dict[str, str | list[str]] | Literal["all"],
    metric: str | None = "euclidean",
    **kwargs,
) -> xr.DataArray | dict[str, xr.DataArray]:
    """Compute pairwise distances between ``individuals`` or ``keypoints``.

    This function computes the distances between
    pairs of ``individuals`` (i.e. interindividual distances) or
    pairs of ``keypoints`` (i.e. interkeypoint distances),
    as determined by ``dim``.
    The distances are computed for the given ``pairs``
    using the specified ``metric``.

    Parameters
    ----------
    data : xarray.DataArray
        The input data containing position information, with ``time``,
        ``space`` (in Cartesian coordinates), and
        ``individuals`` or ``keypoints`` (as specified by ``dim``)
        as required dimensions.
    dim : Literal["individuals", "keypoints"]
        The dimension to compute the distances for. Must be either
        ``'individuals'`` or ``'keypoints'``.
    pairs : dict[str, str | list[str]] or 'all'
        Specifies the pairs of elements (either individuals or keypoints)
        for which to compute distances, depending on the value of ``dim``.

        - If ``dim='individuals'``, ``pairs`` should be a dictionary where
          each key is an individual name, and each value is also an
          individual name or a list of such names to compute distances with.
        - If ``dim='keypoints'``, ``pairs`` should be a dictionary where
          each key is a keypoint name, and each value is also a keypoint
          name or a list of such names to compute distances with.
        - Alternatively, use the special keyword ``'all'`` to compute
          distances for all possible pairs of individuals or keypoints
          (depending on ``dim``).
    metric : str, optional
        The distance metric to use. Must be one of the options supported
        by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
        ``'euclidean'``, etc.
        Defaults to ``'euclidean'``.
    **kwargs : dict
        Additional keyword arguments to pass to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    xarray.DataArray or dict[str, xarray.DataArray]
        The computed pairwise distances. If a single pair is specified in
        ``pairs``, returns an :class:`xarray.DataArray`. If multiple pairs
        are specified, returns a dictionary where each key is a string
        representing the pair (e.g., ``'dist_ind1_ind2'`` or
        ``'dist_key1_key2'``) and each value is an :class:`xarray.DataArray`
        containing the computed distances for that pair.

    Raises
    ------
    ValueError
        If ``dim`` is not one of ``'individuals'`` or ``'keypoints'``;
        if ``pairs`` is neither a dictionary (mapping) nor ``'all'``;
        or if there are no pairs in ``data`` to compute distances for.

    Examples
    --------
    Compute the Euclidean distance (default) between ``ind1`` and
    ``ind2`` (i.e. interindividual distance), for all possible pairs
    of keypoints.

    >>> position = xr.DataArray(
    ...     np.arange(36).reshape(2, 3, 3, 2),
    ...     coords={
    ...         "time": np.arange(2),
    ...         "individuals": ["ind1", "ind2", "ind3"],
    ...         "keypoints": ["key1", "key2", "key3"],
    ...         "space": ["x", "y"],
    ...     },
    ...     dims=["time", "individuals", "keypoints", "space"],
    ... )
    >>> dist_ind1_ind2 = compute_pairwise_distances(
    ...     position, "individuals", {"ind1": "ind2"}
    ... )
    >>> dist_ind1_ind2
    <xarray.DataArray (time: 2, ind1: 3, ind2: 3)> Size: 144B
    8.485 11.31 14.14 5.657 8.485 11.31 ... 5.657 8.485 11.31 2.828 5.657 8.485
    Coordinates:
      * time     (time) int64 16B 0 1
      * ind1     (ind1) <U4 48B 'key1' 'key2' 'key3'
      * ind2     (ind2) <U4 48B 'key1' 'key2' 'key3'

    The resulting ``dist_ind1_ind2`` is a DataArray containing the computed
    distances between ``ind1`` and ``ind2`` for all keypoints
    at each time point.

    To obtain the distances between ``key1`` of ``ind1`` and
    ``key2`` of ``ind2``:

    >>> dist_ind1_ind2.sel(ind1="key1", ind2="key2")

    Compute the Euclidean distance (default) between ``key1`` and
    ``key2`` (i.e. interkeypoint distance), for all possible pairs
    of individuals.

    >>> dist_key1_key2 = compute_pairwise_distances(
    ...     position, "keypoints", {"key1": "key2"}
    ... )
    >>> dist_key1_key2
    <xarray.DataArray (time: 2, key1: 3, key2: 3)> Size: 144B
    2.828 11.31 19.8 5.657 2.828 11.31 14.14 ... 2.828 11.31 14.14 5.657 2.828
    Coordinates:
      * time     (time) int64 16B 0 1
      * key1     (key1) <U4 48B 'ind1' 'ind2' 'ind3'
      * key2     (key2) <U4 48B 'ind1' 'ind2' 'ind3'

    The resulting ``dist_key1_key2`` is a DataArray containing the computed
    distances between ``key1`` and ``key2`` for all individuals
    at each time point.

    To obtain the distances between ``key1`` and ``key2`` within ``ind1``:

    >>> dist_key1_key2.sel(key1="ind1", key2="ind1")

    To obtain the distances between ``key1`` of ``ind1`` and
    ``key2`` of ``ind2``:

    >>> dist_key1_key2.sel(key1="ind1", key2="ind2")

    Compute the city block or Manhattan distance for multiple pairs of
    keypoints using ``position``:

    >>> key_dists = compute_pairwise_distances(
    ...     position,
    ...     "keypoints",
    ...     {"key1": "key2", "key3": ["key1", "key2"]},
    ...     metric="cityblock",
    ... )
    >>> key_dists.keys()
    dict_keys(['dist_key1_key2', 'dist_key3_key1', 'dist_key3_key2'])

    As multiple pairs of keypoints are specified,
    the resulting ``key_dists`` is a dictionary containing the DataArrays
    of computed distances for each pair of keypoints.

    Compute the city block or Manhattan distance for all possible pairs of
    individuals using ``position``:

    >>> ind_dists = compute_pairwise_distances(
    ...     position,
    ...     "individuals",
    ...     "all",
    ...     metric="cityblock",
    ... )
    >>> ind_dists.keys()
    dict_keys(['dist_ind1_ind2', 'dist_ind1_ind3', 'dist_ind2_ind3'])

    See Also
    --------
    scipy.spatial.distance.cdist : The underlying function used.

    """
    if dim not in ["individuals", "keypoints"]:
        raise logger.error(
            ValueError(
                "'dim' must be either 'individuals' or 'keypoints', "
                f"but got {dim}."
            )
        )
    # Only a mapping or the exact string "all" is accepted. Testing the type
    # first avoids an elementwise ``==`` if e.g. an array is passed, and
    # rejects inputs such as None or a list up front with the documented
    # ValueError instead of an AttributeError at ``pairs.items()`` below.
    use_all_pairs = isinstance(pairs, str) and pairs == "all"
    if not (use_all_pairs or isinstance(pairs, Mapping)):
        raise logger.error(
            ValueError(
                f"'pairs' must be a dictionary or 'all', but got {pairs}."
            )
        )
    validate_dims_coords(data, {"time": [], "space": ["x", "y"], dim: []})
    if use_all_pairs:
        # Every unordered pair of distinct labels along ``dim``
        paired_elements = list(
            itertools.combinations(getattr(data, dim).values, 2)
        )
    else:
        paired_elements = [
            (elem1, elem2)
            for elem1, elem2_list in pairs.items()
            for elem2 in (
                # Ensure elem2_list is a list
                [elem2_list] if isinstance(elem2_list, str) else elem2_list
            )
        ]
    if not paired_elements:
        raise logger.error(
            ValueError("Could not find any pairs to compute distances for.")
        )
    pairwise_distances = {
        f"dist_{elem1}_{elem2}": _cdist(
            data.sel({dim: elem1}),
            data.sel({dim: elem2}),
            dim=dim,
            metric=metric,
            **kwargs,
        )
        for elem1, elem2 in paired_elements
    }
    # Return DataArray if result only has one key
    if len(pairwise_distances) == 1:
        return next(iter(pairwise_distances.values()))
    return pairwise_distances


def _validate_labels_dimension(data: xr.DataArray, dim: str) -> xr.DataArray:
    """Ensure ``data`` carries a ``dim`` dimension usable as labels.

    :func:`scipy.spatial.distance.cdist` is applied over a labels
    dimension together with ``space``. If ``data`` has no ``dim``
    coordinate, a placeholder scalar coordinate ``"temp_dim"`` is assigned.
    A scalar ``dim`` coordinate, placeholder or not, is then expanded into a
    length-1 dimension, and the result is ordered ``("time", "space", dim)``.

    Parameters
    ----------
    data : xarray.DataArray
        The input data to validate.
    dim : str
        The name of the labels dimension to ensure.

    Returns
    -------
    xarray.DataArray
        ``data`` itself when ``dim`` is already a non-scalar coordinate,
        otherwise a new DataArray with ``dim`` as a length-1 dimension.

    """
    # ``coords.get`` (rather than ``in``) is used deliberately: for a
    # dimension without an explicit coordinate it yields a virtual range
    # coordinate instead of None.
    labels = data.coords.get(dim)
    if labels is None:
        # No labels at all: attach a placeholder so a dimension can be built.
        data = data.assign_coords({dim: "temp_dim"})
        labels = data.coords[dim]
    if labels.ndim > 0:
        return data
    # Promote the scalar label to a length-1 dimension placed last.
    return data.expand_dims(dim).transpose("time", "space", dim)