Source code for movement.move_accessor

"""Accessor for extending :class:`xarray.Dataset` objects."""

import logging
from collections.abc import Callable
from typing import ClassVar

import xarray as xr

from movement import filtering
from movement.analysis import kinematics
from movement.utils.logging import log_error
from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Preserve the attributes (metadata) of xarray objects after operations.
# NOTE(review): this call runs at import time and is not scoped by a
# ``with`` block, so it presumably changes the option for every xarray
# operation in the process, not only ``movement`` datasets — confirm this
# global side effect is intended.
xr.set_options(keep_attrs=True)


@xr.register_dataset_accessor("move")
class MovementDataset:
    """An :class:`xarray.Dataset` accessor for ``movement`` data.

    A ``movement`` dataset is an :class:`xarray.Dataset` with a specific
    structure to represent pose tracks or bounding boxes data,
    associated confidence scores and relevant metadata.

    Methods/properties that extend the standard ``xarray`` functionality are
    defined in this class. To avoid conflicts with ``xarray``'s namespace,
    ``movement``-specific methods are accessed using the ``move`` keyword,
    for example ``ds.move.validate()`` (see [1]_ for more details).

    Attributes
    ----------
    dim_names : dict
        A dictionary with the names of the expected dimensions in the dataset,
        for each dataset type (``"poses"`` or ``"bboxes"``).
    var_names : dict
        A dictionary with the expected data variables in the dataset,
        for each dataset type (``"poses"`` or ``"bboxes"``).

    References
    ----------
    .. [1] https://docs.xarray.dev/en/stable/internals/extending-xarray.html

    """

    # Set class attributes for expected dimensions and data variables
    dim_names: ClassVar[dict] = {
        "poses": ("time", "individuals", "keypoints", "space"),
        "bboxes": ("time", "individuals", "space"),
    }
    var_names: ClassVar[dict] = {
        "poses": ("position", "confidence"),
        "bboxes": ("position", "shape", "confidence"),
    }

    def __init__(self, ds: xr.Dataset):
        """Initialize the MovementDataset.

        Parameters
        ----------
        ds : xarray.Dataset
            The dataset this accessor is attached to. Its ``ds_type``
            attribute selects which entry of ``dim_names`` and ``var_names``
            applies to this instance.

        """
        self._obj = ds
        # Set instance attributes based on dataset type
        self.dim_names_instance = self.dim_names[self._obj.ds_type]
        self.var_names_instance = self.var_names[self._obj.ds_type]

    def __getattr__(
        self, name: str
    ) -> Callable[..., xr.DataArray | dict[str, xr.DataArray]]:
        """Forward requested but undefined attributes to relevant modules.

        This method currently only forwards kinematic property computation
        and filtering operations to the respective functions in
        :mod:`movement.analysis.kinematics` and
        :mod:`movement.filtering`.

        The name is resolved when the attribute is looked up, not when the
        returned callable is invoked, so that ``hasattr(ds.move, name)`` and
        ``getattr(ds.move, name, default)`` give truthful answers for names
        that cannot be forwarded.

        Parameters
        ----------
        name : str
            The name of the attribute to get.

        Returns
        -------
        Callable
            A callable that forwards its arguments to
            :meth:`kinematics_wrapper` or :meth:`filtering_wrapper` for the
            function called ``name`` and returns that wrapper's result.

        Raises
        ------
        AttributeError
            If ``name`` is neither an attribute of
            :mod:`movement.analysis.kinematics` nor of
            :mod:`movement.filtering`.

        """
        # Kinematics takes precedence over filtering when both define ``name``.
        if hasattr(kinematics, name):
            wrapper = self.kinematics_wrapper
        elif hasattr(filtering, name):
            wrapper = self.filtering_wrapper
        else:
            error_msg = (
                f"'{self.__class__.__name__}' object has "
                f"no attribute '{name}'"
            )
            raise log_error(AttributeError, error_msg)

        def method(*args, **kwargs):
            return wrapper(name, *args, **kwargs)

        return method

    def kinematics_wrapper(
        self, fn_name: str, *args, **kwargs
    ) -> xr.DataArray:
        """Provide convenience method for computing kinematic properties.

        This method forwards kinematic property computation
        to the respective functions in :mod:`movement.analysis.kinematics`.
        The dataset's ``position`` data variable is always passed as the
        first positional argument.

        Parameters
        ----------
        fn_name : str
            The name of the kinematics function to call.
        args : tuple
            Positional arguments to pass to the function.
        kwargs : dict
            Keyword arguments to pass to the function.

        Returns
        -------
        xarray.DataArray
            The computed kinematics attribute value.

        Raises
        ------
        RuntimeError
            If the requested function fails to execute.

        Examples
        --------
        Compute ``displacement`` based on the ``position`` data variable
        in the Dataset ``ds`` and store the result in ``ds``.

        >>> ds["displacement"] = ds.move.compute_displacement()

        Compute ``velocity`` based on the ``position`` data variable in
        the Dataset ``ds`` and store the result in ``ds``.

        >>> ds["velocity"] = ds.move.compute_velocity()

        Compute ``acceleration`` based on the ``position`` data variable
        in the Dataset ``ds`` and store the result in ``ds``.

        >>> ds["acceleration"] = ds.move.compute_acceleration()

        """
        try:
            return getattr(kinematics, fn_name)(
                self._obj.position, *args, **kwargs
            )
        except Exception as e:
            # Any failure (unknown function, missing ``position``, bad
            # arguments) is reported uniformly as a RuntimeError.
            error_msg = (
                f"Failed to evoke '{fn_name}' via 'move' accessor. {e}"
            )
            raise log_error(RuntimeError, error_msg) from e

    def filtering_wrapper(
        self, fn_name: str, *args, data_vars: list[str] | None = None, **kwargs
    ) -> xr.DataArray | dict[str, xr.DataArray]:
        """Provide convenience method for filtering data variables.

        This method forwards filtering and/or smoothing to the respective
        functions in :mod:`movement.filtering`. The data variables to
        filter can be specified in ``data_vars``. If ``data_vars`` is not
        specified, the ``position`` data variable is selected by default.

        Parameters
        ----------
        fn_name : str
            The name of the filtering function to call.
        args : tuple
            Positional arguments to pass to the function.
        data_vars : list[str] | None
            The data variables to apply filtering. If ``None``, the
            ``position`` data variable will be passed by default.
        kwargs : dict
            Keyword arguments to pass to the function.

        Returns
        -------
        xarray.DataArray | dict[str, xarray.DataArray]
            The filtered data variable or a dictionary of filtered data
            variables, if multiple data variables are specified.

        Raises
        ------
        RuntimeError
            If the requested function fails to execute.

        Examples
        --------
        Filter the ``position`` data variable to drop points with
        ``confidence`` below 0.7 and store the result back into the
        Dataset ``ds``.
        Since ``data_vars`` is not supplied, the filter will be applied to
        the ``position`` data variable by default.

        >>> ds["position"] = ds.move.filter_by_confidence(threshold=0.7)

        Apply a median filter to the ``position`` data variable and
        store this back into the Dataset ``ds``.

        >>> ds["position"] = ds.move.median_filter(window=3)

        Apply a Savitzky-Golay filter to both the ``position`` and
        ``velocity`` data variables and store these back into the
        Dataset ``ds``. ``filtered_data`` is a dictionary, where the keys
        are the data variable names and the values are the filtered
        DataArrays.

        >>> filtered_data = ds.move.savgol_filter(
        ...     window=3, data_vars=["position", "velocity"]
        ... )
        >>> ds.update(filtered_data)

        """
        ds = self._obj
        if data_vars is None:  # Default to filter on position
            data_vars = ["position"]
        if fn_name == "filter_by_confidence":
            # Add confidence to kwargs
            kwargs["confidence"] = ds.confidence
        try:
            result = {
                data_var: getattr(filtering, fn_name)(
                    ds[data_var], *args, **kwargs
                )
                for data_var in data_vars
            }
        except Exception as e:
            error_msg = (
                f"Failed to evoke '{fn_name}' via 'move' accessor. {e}"
            )
            raise log_error(RuntimeError, error_msg) from e
        # Return DataArray if result only has one key
        if len(result) == 1:
            return next(iter(result.values()))
        return result

    def validate(self) -> None:
        """Validate the dataset.

        This method checks if the dataset contains the expected dimensions,
        data variables, and metadata attributes. It also ensures that the
        dataset contains valid poses or bounding boxes data.

        Raises
        ------
        ValueError
            If the dataset is missing required dimensions, data variables,
            or contains invalid poses or bounding boxes data.

        """
        fps = self._obj.attrs.get("fps", None)
        source_software = self._obj.attrs.get("source_software", None)
        try:
            self._validate_dims()
            self._validate_data_vars()
            if self._obj.ds_type == "poses":
                ValidPosesDataset(
                    position_array=self._obj["position"].values,
                    confidence_array=self._obj["confidence"].values,
                    individual_names=self._obj.coords["individuals"].values,
                    keypoint_names=self._obj.coords["keypoints"].values,
                    fps=fps,
                    source_software=source_software,
                )
            elif self._obj.ds_type == "bboxes":
                # Define frame_array.
                # Recover from time axis in seconds if necessary.
                frame_array = self._obj.coords["time"].values.reshape(-1, 1)
                if self._obj.attrs["time_unit"] == "seconds":
                    # Multiply out-of-place: ``reshape`` may return a view of
                    # the coordinate's buffer, so ``*=`` could overwrite the
                    # dataset's time axis (or fail on a read-only buffer or
                    # on an integer time axis with a float ``fps``).
                    frame_array = frame_array * fps
                ValidBboxesDataset(
                    position_array=self._obj["position"].values,
                    shape_array=self._obj["shape"].values,
                    confidence_array=self._obj["confidence"].values,
                    individual_names=self._obj.coords["individuals"].values,
                    frame_array=frame_array,
                    fps=fps,
                    source_software=source_software,
                )
        except Exception as e:
            # Every failure above, including a missing ``time_unit``
            # attribute, is reported as a single ValueError.
            error_msg = (
                f"The dataset does not contain valid {self._obj.ds_type}. {e}"
            )
            raise log_error(ValueError, error_msg) from e

    def _validate_dims(self) -> None:
        """Raise ``ValueError`` if any expected dimension is missing."""
        missing_dims = set(self.dim_names_instance) - set(self._obj.dims)
        if missing_dims:
            raise ValueError(
                f"Missing required dimensions: {sorted(missing_dims)}"
            )  # sort for a reproducible error message

    def _validate_data_vars(self) -> None:
        """Raise ``ValueError`` if any expected data variable is missing."""
        missing_vars = set(self.var_names_instance) - set(self._obj.data_vars)
        if missing_vars:
            raise ValueError(
                f"Missing required data variables: {sorted(missing_vars)}"
            )  # sort for a reproducible error message