"""Compute and apply spatial transforms."""
import itertools
import cv2
import numpy as np
import xarray as xr
from numpy.typing import ArrayLike
from movement.utils.logging import log_to_attrs
from movement.validators.arrays import validate_dims_coords
# [docs]
@log_to_attrs
def scale(
    data: xr.DataArray,
    factor: ArrayLike | float,
    space_unit: str | None = None,
) -> xr.DataArray:
    """Scale data by a given factor with an optional unit.

    Parameters
    ----------
    data
        The input data to be scaled.
    factor
        The scaling factor to apply to the data. If factor is a scalar (a
        single float), the data array is uniformly scaled by the same factor.
        If factor is an object that can be converted to a 1D numpy array (e.g.
        a list of floats), the length of the resulting array must match the
        length of data array's space dimension along which it will be
        broadcasted.
    space_unit
        The unit of the scaled data stored as a property in
        ``xarray.DataArray.attrs['space_unit']``. In case of the default
        (``None``) the ``space_unit`` attribute is dropped.

    Returns
    -------
    xarray.DataArray
        The scaled data array.

    Notes
    -----
    This function makes two changes to the resulting data array's attributes
    (:attr:`xarray.DataArray.attrs`) each time it is called:

    - It sets the ``space_unit`` attribute to the value of the parameter
      with the same name, or removes it if ``space_unit=None``.
    - It adds a new entry to the ``log`` attribute of the data array, which
      contains a record of the operations performed, including the
      parameters used, as well as the datetime of the function call.

    Examples
    --------
    Let's imagine a camera viewing a 2D plane from the top, with an
    estimated resolution of 10 pixels per cm. We can scale down
    position data by a factor of 1/10 to express it in cm units.

    >>> from movement.transforms import scale
    >>> ds["position"] = scale(ds["position"], factor=1 / 10, space_unit="cm")
    >>> print(ds["position"].space_unit)
    cm
    >>> print(ds["position"].log)
    [
        {
            "operation": "scale",
            "datetime": "2025-06-05 15:08:16.919947",
            "factor": "0.1",
            "space_unit": "'cm'"
        }
    ]

    Note that the attributes of the scaled data array now contain the assigned
    ``space_unit`` as well as a ``log`` entry with the arguments passed to
    the function.

    We can also scale the two spatial dimensions by different factors.

    >>> ds["position"] = scale(ds["position"], factor=[10, 20])

    The second scale operation restored the x axis to its original scale,
    and scaled up the y axis to twice its original size.
    The log will now contain two entries, but the ``space_unit`` attribute
    has been removed, as it was not provided in the second function call.

    >>> "space_unit" in ds["position"].attrs
    False

    """
    # Validate that the space dimension holds the expected 2D or 3D coords.
    if len(data.coords["space"]) == 2:
        validate_dims_coords(data, {"space": ["x", "y"]})
    else:
        validate_dims_coords(data, {"space": ["x", "y", "z"]})
    if not np.isscalar(factor):
        factor = np.array(factor).squeeze()
        # Guard clauses: each raise exits, so no elif chain is needed.
        if factor.ndim != 1:
            raise ValueError(
                "Factor must be an object that can be converted to a 1D numpy"
                f" array, got {factor.ndim}D"
            )
        if factor.shape != data.space.values.shape:
            raise ValueError(
                f"Factor shape {factor.shape} does not match the shape "
                f"of the space dimension {data.space.values.shape}"
            )
        # Reshape the factor so it broadcasts along the space axis only.
        factor_dims = [1] * data.ndim  # 1s array matching data dimensions
        factor_dims[data.get_axis_num("space")] = factor.shape[0]
        factor = factor.reshape(factor_dims)
    scaled_data = data * factor
    if space_unit is not None:
        scaled_data.attrs["space_unit"] = space_unit
    else:
        # No unit supplied: drop any stale unit from a previous scaling.
        scaled_data.attrs.pop("space_unit", None)
    return scaled_data
# [docs]
def poses_to_bboxes(
    position: xr.DataArray,
    padding: float = 0.0,
) -> tuple[xr.DataArray, xr.DataArray]:
    """Compute bounding box centroid and shape from a poses position array.

    For every individual at every time point, the bounding box is obtained
    from the minimum and maximum coordinates over all valid keypoints. The
    box is reported as a centroid (its center point) together with a shape
    (its width and height).

    Parameters
    ----------
    position : xarray.DataArray
        A 2D poses position array with dimensions
        ``(time, space, keypoints, individuals)``, where the ``space``
        coordinate contains exactly ``["x", "y"]``.
    padding : float, optional
        Number of pixels added around the box in every direction, so both
        width and height grow by ``2 * padding``. Defaults to 0.0.

    Returns
    -------
    tuple[xarray.DataArray, xarray.DataArray]
        A ``(position, shape)`` pair: bounding box centroids and bounding
        box width/height, each with dimensions ``(time, space,
        individuals)``.

    Raises
    ------
    TypeError
        If ``position`` is not an :class:`xarray.DataArray` or ``padding``
        is not numeric.
    ValueError
        If the position array lacks the required dimensions/coordinates,
        is not 2D, or ``padding`` is negative.

    Notes
    -----
    - A keypoint with NaN in any spatial coordinate is ignored; if every
      keypoint of an individual is NaN at a time point, the centroid and
      shape are NaN there too.
    - Centroid = ``(min + max) / 2`` per spatial dimension; shape = span
      plus padding (``max - min + 2 * padding``).
    - A single valid keypoint yields zero width/height before padding.

    Examples
    --------
    Compute bounding boxes from a poses dataset with zero padding:

    >>> from movement.transforms import poses_to_bboxes
    >>> bbox_position, bbox_shape = poses_to_bboxes(poses_ds["position"])

    Compute bounding boxes from a poses dataset with 10 pixels of padding:

    >>> bbox_position, bbox_shape = poses_to_bboxes(
    ...     poses_ds["position"], padding=10
    ... )

    See Also
    --------
    movement.transforms.scale : Scale spatial coordinates

    """
    if not isinstance(position, xr.DataArray):
        raise TypeError(
            f"Expected an xarray DataArray, but got {type(position)}."
        )
    validate_dims_coords(
        position,
        {"time": [], "space": ["x", "y"], "keypoints": [], "individuals": []},
        exact_coords=True,
    )
    if not isinstance(padding, int | float):
        raise TypeError(
            f"padding must be a number, got {type(padding).__name__}"
        )
    if padding < 0:
        raise ValueError(f"padding must be non-negative, got {padding}")
    # Only keypoints with every spatial coordinate present are considered.
    complete_keypoints = ~position.isnull().any(dim="space")
    usable = position.where(complete_keypoints)
    lower = usable.min(dim="keypoints", skipna=True)
    upper = usable.max(dim="keypoints", skipna=True)
    return (lower + upper) / 2, upper - lower + 2 * padding
def _validate_points_shape(src_points: np.ndarray, dst_points: np.ndarray):
"""Validate that source and destination point arrays.
The arrays should have matching 2D shapes.
"""
if len(src_points.shape) != 2 or len(dst_points.shape) != 2:
raise ValueError("Points must be 2-dimensional arrays.")
if src_points.shape != dst_points.shape:
raise ValueError(
"Source and destination points must have the same shape."
)
dim = src_points.shape[1]
if dim != 2:
raise ValueError("Points must be 2-dimensional.")
def _filter_invalid_points(src_pts: np.ndarray, dst_pts: np.ndarray):
    """Remove invalid point pairs from the input sets.

    A source point is dropped (together with its paired destination point)
    when it duplicates an already-kept source point, or when keeping it
    would leave the kept source points collinear. Filtering decisions are
    based on the source points only; destination points are subset with
    the same indices.

    Returns the filtered ``(src_pts, dst_pts)`` arrays.
    """
    keep_idx: list[int] = []
    # Latch: once the kept set contains a non-collinear triple, later
    # points skip the pairwise collinearity screening below.
    obtained_min_non_collinear_set = False
    eps = 1e-6  # tolerance for "duplicate" / "collinear" decisions
    for i in range(len(src_pts)):
        # skip duplicates (closer than eps to any already-kept point)
        if any(
            np.linalg.norm(src_pts[i] - src_pts[j]) < eps for j in keep_idx
        ):
            continue
        # Candidate set = points kept so far plus the current point.
        subset = np.vstack([src_pts[j] for j in keep_idx] + [src_pts[i]])
        if subset.shape[0] < 3:
            # Fewer than 3 points cannot be collinear; always keep.
            keep_idx.append(i)
            continue
        elif subset.shape[0] == 3 and _is_collinear_set(subset, eps):
            # Current point would make the first triple collinear; drop it.
            continue
        # If we have at least 3 old points, check that
        # new point is not collinear with any other two
        if not obtained_min_non_collinear_set and subset.shape[0] > 3:
            all_noncollinear_triples = all(
                not _is_collinear_three(
                    src_pts[a], src_pts[b], src_pts[i], eps
                )
                for a, b in itertools.combinations(keep_idx, 2)
            )
            if all_noncollinear_triples:
                obtained_min_non_collinear_set = True
            else:
                # Collinear with some kept pair; drop the current point.
                continue
        keep_idx.append(i)
    return src_pts[keep_idx], dst_pts[keep_idx]
def _is_collinear_three(a, b, c, eps):
"""Check if three 2D points are collinear via the cross-product method."""
return (
abs((b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0]))
<= eps
)
def _is_collinear_set(points: np.ndarray, eps):
"""Check if a set of 2D points is collinear.
Uses singular value decomposition (SVD) to determine
whether all points lie on a single straight line.
"""
pts = np.array(points)
pts -= pts.mean(axis=0)
_, s, _ = np.linalg.svd(pts)
rank = np.sum(s > eps)
return rank < 2