# Source code for movement.io.nwb

"""Helpers to convert between movement poses datasets and NWB files.

The pose tracks in NWB files are formatted according to the ``ndx-pose``
NWB extension, see https://github.com/rly/ndx-pose.
"""

import datetime
from typing import Any

import ndx_pose
import pynwb
import xarray as xr
from attrs import define, field

from movement.utils.logging import logger

# Shape accepted by every ``*_kwargs`` option of ``NWBFileSaveConfig``: either
# a single dictionary of keyword arguments shared by all entities, or a
# dictionary keyed by individual/keypoint name whose values are the
# per-entity keyword-argument dictionaries.
ConfigKwargsType = dict[str, Any] | dict[str, dict[str, Any]]


def _safe_dict_field() -> Any:
    """Build an ``attrs`` field whose value is always a dictionary.

    The field defaults to a fresh empty dictionary, and any value that is
    not a ``dict`` is replaced by an empty dictionary on conversion.
    """

    # Deliberately left unannotated, like the lambda it replaces.
    def _dict_or_empty(value):
        if isinstance(value, dict):
            return value
        return {}

    return field(factory=dict, converter=_dict_or_empty)


@define(kw_only=True)
class NWBFileSaveConfig:
    """Configuration for saving ``movement poses`` dataset to NWBFile(s).

    This class is used with :func:`movement.io.save_poses.to_nwb_file`
    to add custom metadata to the NWBFile(s) created from a given
    ``movement`` dataset.

    Attributes
    ----------
    nwbfile_kwargs : dict[str, Any] or dict[str, dict[str, Any]], optional
        Keyword arguments for :class:`pynwb.file.NWBFile`.

        If ``nwbfile_kwargs`` is a single dictionary, the same keyword
        arguments will be applied to all NWBFile objects except for
        ``identifier``.

        If ``nwbfile_kwargs`` is a dictionary of dictionaries, the outer keys
        should correspond to individual names in the ``movement`` dataset,
        and the inner dictionaries will be passed as keyword arguments to the
        :class:`pynwb.file.NWBFile` constructor.

        The following arguments cannot be overwritten:

        - ``subject``: :class:`pynwb.file.Subject` created for the individual
          using ``subject_kwargs``

        The following arguments will have default values if not set:

        - ``session_description``: "not set"
        - ``session_start_time``: current UTC time

        ``identifier`` will be set in the following order of precedence:

        1. ``identifier`` in the inner dictionary
        2. ``nwbfile_kwargs["identifier"]`` (single-individual dataset only)
        3. individual name in the ``movement`` dataset
    processing_module_kwargs: dict[str, Any] or dict[str, dict[str, Any]], optional
        Keyword arguments for :class:`pynwb.base.ProcessingModule`.

        If ``processing_module_kwargs`` is a single dictionary, the same
        keyword arguments will be applied to all ProcessingModules.

        If ``processing_module_kwargs`` is a dictionary of dictionaries, the
        outer keys should correspond to individual names in the ``movement``
        dataset, and the inner dictionaries will be passed as keyword
        arguments to the :class:`pynwb.base.ProcessingModule` constructor.

        The following arguments will have default values if not set:

        - ``name``: "behavior"
        - ``description``: "processed behavioral data"
    subject_kwargs : dict[str, Any] or dict[str, dict[str, Any]], optional
        Keyword arguments for :class:`pynwb.file.Subject`.

        If ``subject_kwargs`` is a single dictionary, the same keyword
        arguments will be applied to all Subjects except for ``subject_id``.

        If ``subject_kwargs`` is a dictionary of dictionaries, the outer keys
        should correspond to individual names in the ``movement`` dataset,
        and the inner dictionaries will be passed as keyword arguments to the
        :class:`pynwb.file.Subject` constructor.

        ``subject_id`` will be set in the following order of precedence:

        1. ``subject_id`` in the inner dictionary
        2. ``subject_kwargs["subject_id"]`` (single-individual dataset only)
        3. individual name in the ``movement`` dataset
    pose_estimation_series_kwargs : dict[str, Any] or dict[str, dict[str, Any]], optional
        Keyword arguments for ``ndx_pose.PoseEstimationSeries`` [1]_.

        If ``pose_estimation_series_kwargs`` is a single dictionary, the same
        keyword arguments will be applied to all PoseEstimationSeries
        objects.

        If ``pose_estimation_series_kwargs`` is a dictionary of dictionaries,
        the outer keys should correspond to keypoint names in the
        ``movement`` dataset, and the inner dictionaries will be passed as
        keyword arguments to the ``ndx_pose.PoseEstimationSeries``
        constructor.

        The following arguments will be set based on the dataset and
        cannot be overwritten:

        - ``data``: position data for the keypoint
        - ``confidence``: confidence data for the keypoint
        - ``timestamps``: time data for the keypoint

        The following arguments will have default values if not set:

        - ``unit``: "pixels"
        - ``reference_frame``: "(0,0,0) corresponds to ..."

        ``name`` will be set in the following order of precedence:

        1. ``name`` in the inner dictionary
        2. ``pose_estimation_series_kwargs["name"]``
           (single-keypoint dataset only)
        3. keypoint name in the ``movement`` dataset
    pose_estimation_kwargs : dict[str, Any] or dict[str, dict[str, Any]], optional
        Keyword arguments for ``ndx_pose.PoseEstimation`` [1]_.

        If ``pose_estimation_kwargs`` is a single dictionary, the same
        keyword arguments will be applied to all PoseEstimation objects.

        If ``pose_estimation_kwargs`` is a dictionary of dictionaries, the
        outer keys should correspond to individual names in the ``movement``
        dataset, and the inner dictionaries will be passed as keyword
        arguments to the ``ndx_pose.PoseEstimation`` constructor.

        The following arguments cannot be overwritten:

        - ``pose_estimation_series``: list of PoseEstimationSeries objects
        - ``skeleton``: Skeleton object

        The following arguments will have default values if not set:

        - ``source_software``: ``source_software`` attribute from the
          ``movement`` dataset
        - ``description``: "Estimated positions of <keypoints> for
          <individual> using <source_software>."

        If specified, ``name`` will be set in the following order of
        precedence:

        1. ``name`` in the inner dictionary
        2. ``pose_estimation_kwargs["name"]``
           (single-individual dataset only)
        3. individual name in the ``movement`` dataset
    skeleton_kwargs : dict[str, Any] or dict[str, dict[str, Any]], optional
        Keyword arguments for ``ndx_pose.Skeleton`` [1]_.

        If ``skeleton_kwargs`` is a single dictionary, the same keyword
        arguments will be applied to all Skeleton objects.

        If ``skeleton_kwargs`` is a dictionary of dictionaries, the outer
        keys should correspond to individual names in the ``movement``
        dataset, and the inner dictionaries will be passed as keyword
        arguments to the ``ndx_pose.Skeleton`` constructor.

        The following arguments cannot be overwritten:

        - ``subject``: :class:`pynwb.file.Subject` created for the individual
          using ``subject_kwargs``

        The following arguments will have default values if not set:

        - ``name``: "skeleton_<individual>"
        - ``nodes``: list of keypoint names in the dataset

        If ``name`` is specified anywhere in ``skeleton_kwargs``, it will be
        set in the following order of precedence:

        1. ``name`` in the inner dictionary
        2. ``skeleton_kwargs["name"]`` (single-individual dataset only)
        3. individual name in the ``movement`` dataset

    References
    ----------
    .. [1] https://github.com/rly/ndx-pose

    See Also
    --------
    movement.io.save_poses.to_nwb_file
        Example usage of this class to save a ``movement`` dataset
        to an NWB file.

    """  # noqa: E501

    nwbfile_kwargs: ConfigKwargsType = _safe_dict_field()
    processing_module_kwargs: ConfigKwargsType = _safe_dict_field()
    subject_kwargs: ConfigKwargsType = _safe_dict_field()
    pose_estimation_series_kwargs: ConfigKwargsType = _safe_dict_field()
    pose_estimation_kwargs: ConfigKwargsType = _safe_dict_field()
    skeleton_kwargs: ConfigKwargsType = _safe_dict_field()

    # Class-level fallbacks, looked up by ``_resolve_kwargs`` under the name
    # ``DEFAULT_<ATTR_NAME>`` and merged underneath the user-provided kwargs.
    DEFAULT_NWBFILE_KWARGS = dict(session_description="not set")
    DEFAULT_PROCESSING_MODULE_KWARGS = dict(
        name="behavior", description="processed behavioral data"
    )
    DEFAULT_POSE_ESTIMATION_SERIES_KWARGS = dict(
        reference_frame="(0,0,0) corresponds to ...", unit="pixels"
    )

    def _resolve_kwargs(
        self,
        attr_name: str,
        entity: str | None,
        entity_type: str,
        id_key: str,
        prioritise_entity: bool = False,
    ) -> dict[str, Any]:
        """Resolve per-entity (individual/keypoint) or shared kwargs.

        If the kwargs attribute (retrieved from ``attr_name``) is a
        dictionary of dictionaries, the outer keys should correspond to
        individual or keypoint names in the ``movement`` dataset, and a copy
        of the inner dictionary for ``entity`` will be returned, with
        ``id_key`` falling back to ``entity`` when the inner dictionary does
        not set it.

        If the retrieved attribute is a single dictionary, the same arguments
        will be used for all individuals or keypoints except for ``id_key``,
        which is resolved according to ``prioritise_entity``:

        - ``prioritise_entity=True``: ``entity`` (if not None), otherwise
          ``kwargs[id_key]`` (if provided).
        - ``prioritise_entity=False``: ``kwargs[id_key]`` (if provided),
          otherwise ``entity`` (which may be None).

        In both cases, keys that are still missing afterwards are filled in
        from the ``DEFAULT_<ATTR_NAME>`` class attribute, if one exists.

        Parameters
        ----------
        attr_name : str
            The name of the attribute in the class (e.g. ``nwbfile_kwargs``,
            ``subject_kwargs``, ``pose_estimation_series_kwargs``) to be
            resolved.
        entity : str or None
            Individual or keypoint name.
        entity_type : str
            Type of entity (i.e. "individual", "keypoint").
            Used in error or warning messages.
        id_key : str
            The key in the resolved kwargs corresponding to the entity
            identifier/name (e.g. ``identifier``, ``subject_id``, ``name``)
            to be set in the returned dictionary.
        prioritise_entity : bool, optional
            Flag indicating whether ``entity`` should take precedence
            over the ``id_key`` when the attribute is a shared config
            (i.e. a single dictionary). Default is False.

        Returns
        -------
        dict[str, Any]
            Keyword arguments for a specific ``entity_type``.

        Raises
        ------
        ValueError
            If the attribute is a per-entity config with more than one
            entry and ``entity`` is None (logged via ``logger.error``).

        """

        def infer_entity_or_raise(cfg: dict) -> str:
            # A single-entry per-entity config is unambiguous, so its only
            # key is used; anything else cannot be resolved without a name.
            if len(cfg) == 1:
                inferred = next(iter(cfg))
                logger.warning(
                    f"No {entity_type} was provided. Assuming '{inferred}' "
                    f"since there is only one entry in {attr_name}."
                )
                return inferred
            raise logger.error(
                ValueError(
                    f"NWBFileSaveConfig has per-{entity_type} "
                    f"{attr_name}, but no {entity_type} was provided."
                )
            )

        cfg = getattr(self, attr_name)
        defaults = getattr(self, f"DEFAULT_{attr_name.upper()}", None)
        if self._is_per_entity_config(cfg):
            if entity is None:
                entity = infer_entity_or_raise(cfg)
            # Copy so the stored config is never mutated below
            base = dict(cfg.get(entity, {}))
            if not base:
                logger.warning(
                    f"'{entity}' not found in {attr_name}; "
                    f"setting '{entity}' as {id_key}."
                )
            base.setdefault(id_key, entity)
        else:  # Shared config
            base = dict(cfg)
            if not prioritise_entity:
                base[id_key] = cfg.get(id_key, entity)
            elif entity is not None:
                base[id_key] = entity
        if defaults:
            base = {**defaults, **base}  # base overrides defaults
        return base

    def _resolve_nwbfile_kwargs(
        self, individual: str | None = None, prioritise_individual: bool = True
    ) -> dict[str, Any]:
        """Resolve the keyword arguments for :class:`pynwb.file.NWBFile`.

        Parameters
        ----------
        individual : str, optional
            Individual name. If provided, the method will attempt to
            retrieve individual-specific settings or fall back to
            shared or default settings.
        prioritise_individual: bool, optional
            Flag indicating whether ``individual`` should take precedence
            over the ``identifier`` in shared ``nwbfile_kwargs``.
            Default is True.

        Returns
        -------
        dict[str, Any]
            Keyword arguments to be passed to :class:`pynwb.file.NWBFile`.
            ``session_start_time`` is set to the current UTC time (with a
            warning) if it was not provided.

        """
        kwargs = self._resolve_kwargs(
            attr_name="nwbfile_kwargs",
            entity=individual,
            entity_type="individual",
            id_key="identifier",
            prioritise_entity=prioritise_individual,
        )
        if "session_start_time" not in kwargs:
            logger.warning(
                "No session_start_time provided in nwbfile_kwargs; "
                "using current UTC time as default."
            )
            kwargs["session_start_time"] = datetime.datetime.now(datetime.UTC)
        return kwargs

    def _resolve_processing_module_kwargs(
        self, individual: str | None = None
    ) -> dict[str, Any]:
        """Resolve the keyword arguments for :class:`pynwb.base.ProcessingModule`.

        Parameters
        ----------
        individual : str, optional
            Individual name. If provided, the method will attempt to
            retrieve individual-specific settings or fall back to
            shared or default settings.

        Returns
        -------
        dict[str, Any]
            Keyword arguments to be passed to
            :class:`pynwb.base.ProcessingModule`.

        """  # noqa: E501
        kwargs = self._resolve_kwargs(
            attr_name="processing_module_kwargs",
            entity=individual,
            entity_type="individual",
            id_key="name",
        )
        # ``_resolve_kwargs`` falls back to the individual name (or None) for
        # ``name``; in that case the class-level default name is used instead.
        if kwargs.get("name") in (individual, None):
            kwargs["name"] = self.DEFAULT_PROCESSING_MODULE_KWARGS["name"]
        return kwargs

    def _resolve_subject_kwargs(
        self, individual: str | None = None, prioritise_individual: bool = True
    ) -> dict[str, Any]:
        """Resolve the keyword arguments for :class:`pynwb.file.Subject`.

        Parameters
        ----------
        individual : str, optional
            Individual name. If provided, the method will attempt to
            retrieve individual-specific settings or fall back to
            shared or default settings.
        prioritise_individual: bool, optional
            Flag indicating whether ``individual`` should take precedence
            over the ``subject_id`` in shared ``subject_kwargs``.
            Default is True.

        Returns
        -------
        dict[str, Any]
            Keyword arguments to be passed to :class:`pynwb.file.Subject`.

        """
        return self._resolve_kwargs(
            attr_name="subject_kwargs",
            entity=individual,
            entity_type="individual",
            id_key="subject_id",
            prioritise_entity=prioritise_individual,
        )

    def _resolve_pose_estimation_series_kwargs(
        self, keypoint: str | None = None, prioritise_keypoint: bool = True
    ) -> dict[str, Any]:
        """Resolve the keyword arguments for ``ndx_pose.PoseEstimationSeries``.

        Parameters
        ----------
        keypoint : str, optional
            Keypoint name. If provided, the method will attempt to
            retrieve keypoint-specific settings or fall back to
            shared or default settings.
        prioritise_keypoint: bool, optional
            Flag indicating whether ``keypoint`` should take precedence
            over the ``name`` in shared ``pose_estimation_series_kwargs``.
            Default is True.

        Returns
        -------
        dict[str, Any]
            Keyword arguments to be passed to
            ``ndx_pose.PoseEstimationSeries``.

        """
        return self._resolve_kwargs(
            attr_name="pose_estimation_series_kwargs",
            entity=keypoint,
            entity_type="keypoint",
            id_key="name",
            prioritise_entity=prioritise_keypoint,
        )

    def _resolve_pose_estimation_kwargs(
        self,
        individual: str | None = None,
        prioritise_individual: bool = True,
        defaults: dict | None = None,
    ) -> dict[str, Any]:
        """Resolve the keyword arguments for ``ndx_pose.PoseEstimation``.

        Parameters
        ----------
        individual : str, optional
            Individual name. If provided, the method will attempt to
            retrieve individual-specific settings or fall back to
            shared or default settings.
        prioritise_individual: bool, optional
            Flag indicating whether ``individual`` should take precedence
            over the ``name`` in shared ``pose_estimation_kwargs``.
            Default is True.
        defaults : dict, optional
            Dataset-specific default values to be used for keys that are
            not already set.

        Returns
        -------
        dict[str, Any]
            Keyword arguments to be passed to ``ndx_pose.PoseEstimation``.
            ``name`` is omitted unless it appears somewhere in
            ``pose_estimation_kwargs``.

        """
        kwargs = self._resolve_kwargs(
            attr_name="pose_estimation_kwargs",
            entity=individual,
            entity_type="individual",
            id_key="name",
            prioritise_entity=prioritise_individual,
        )
        if defaults is not None:
            for key, value in defaults.items():
                kwargs.setdefault(key, value)
        if not self._has_key(self.pose_estimation_kwargs, "name"):
            kwargs.pop("name", None)  # Use ndx_pose default
        return kwargs

    def _resolve_skeleton_kwargs(
        self,
        individual: str | None = None,
        prioritise_individual: bool = True,
        defaults: dict | None = None,
    ) -> dict[str, Any]:
        """Resolve the keyword arguments for ``ndx_pose.Skeleton``.

        Parameters
        ----------
        individual : str, optional
            Individual name. If provided, the method will attempt to
            retrieve individual-specific settings or fall back to
            shared or default settings.
        prioritise_individual: bool, optional
            Flag indicating whether ``individual`` should take precedence
            over the ``name`` in shared ``skeleton_kwargs``.
            Default is True.
        defaults : dict, optional
            Dataset-specific default values to be used for keys that are
            not already set.

        Returns
        -------
        dict[str, Any]
            Keyword arguments to be passed to ``ndx_pose.Skeleton``.
            If ``name`` does not appear anywhere in ``skeleton_kwargs``, it
            is set to ``"skeleton_<individual>"`` (or ``"skeleton"`` when no
            individual is given).

        """
        kwargs = self._resolve_kwargs(
            attr_name="skeleton_kwargs",
            entity=individual,
            entity_type="individual",
            id_key="name",
            prioritise_entity=prioritise_individual,
        )
        if defaults is not None:
            for key, value in defaults.items():
                kwargs.setdefault(key, value)
        if not self._has_key(self.skeleton_kwargs, "name"):
            kwargs["name"] = (
                f"skeleton{f'_{individual}' if individual else ''}"
            )
        return kwargs

    @staticmethod
    def _is_per_entity_config(cfg: ConfigKwargsType) -> bool:
        """Return True if ``cfg`` is non-empty and all its values are dicts."""
        return bool(cfg) and all(isinstance(v, dict) for v in cfg.values())

    @staticmethod
    def _has_key(cfg: ConfigKwargsType, key: str) -> bool:
        """Return True if ``key`` is in ``cfg`` or in any of its dict values."""
        if not isinstance(cfg, dict):
            return False
        return key in cfg or any(
            isinstance(sub_dict, dict) and key in sub_dict
            for sub_dict in cfg.values()
        )
def _ds_to_pose_and_skeletons(
    ds: xr.Dataset,
    config: NWBFileSaveConfig | None = None,
    subject: pynwb.file.Subject | None = None,
    from_multi_individual: bool = False,
) -> tuple[ndx_pose.PoseEstimation, ndx_pose.Skeletons]:
    """Create PoseEstimation and Skeletons objects from a ``movement`` dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        A single-individual ``movement`` poses dataset.
    config : movement.io.nwb.NWBFileSaveConfig
        Configuration object containing keyword arguments to customise
        the PoseEstimation and Skeletons objects created from the dataset.
        If None (default), default values will be used.
        See :class:`movement.io.nwb.NWBFileSaveConfig` for more details.
    subject : pynwb.file.Subject, optional
        Subject object to be linked in the Skeleton object.
    from_multi_individual : bool, optional
        Flag indicating whether ``ds`` originates from a multi-individual
        dataset. Passed to the ``NWBFileSaveConfig`` methods to determine
        whether to prioritise individual names in the dataset over
        ``name`` in shared ``pose_estimation_kwargs`` and ``skeleton_kwargs``.
        Default is False.

    Returns
    -------
    pose_estimation : ndx_pose.PoseEstimation
        PoseEstimation object containing PoseEstimationSeries objects
        for each keypoint in the dataset.
    skeletons : ndx_pose.Skeletons
        Skeletons object containing all Skeleton objects.

    Raises
    ------
    ValueError
        If ``ds`` does not contain exactly one individual.

    """
    if ds.individuals.size != 1:
        raise logger.error(
            ValueError(
                "Dataset must contain only one individual to create "
                "PoseEstimation and Skeletons objects."
            )
        )
    config = config or NWBFileSaveConfig()
    individual = ds.individuals.values.item()
    keypoints = ds.keypoints.values.tolist()
    # Convert timestamps to seconds if necessary
    if ds.time_unit == "seconds":
        timestamps = ds.time.values
    else:
        # Fall back to 1 fps both when ``fps`` is missing and when it is
        # stored as None, so the division below cannot fail on None.
        fps = getattr(ds, "fps", None)
        timestamps = ds.time.values / (1.0 if fps is None else fps)
    # The keypoint name only overrides a shared ``name`` when there are
    # several keypoints to tell apart.
    prioritise_keypoint = len(keypoints) > 1
    pose_estimation_series = []
    for keypoint in keypoints:
        keypoint_ds = ds.sel(keypoints=keypoint)  # select once per keypoint
        pose_estimation_series.append(
            ndx_pose.PoseEstimationSeries(
                data=keypoint_ds.position.values,
                confidence=keypoint_ds.confidence.values,
                timestamps=timestamps,
                **config._resolve_pose_estimation_series_kwargs(
                    keypoint, prioritise_keypoint
                ),
            )
        )
    skeleton_list = [
        ndx_pose.Skeleton(
            subject=subject,
            **config._resolve_skeleton_kwargs(
                individual, from_multi_individual, {"nodes": keypoints}
            ),
        )
    ]
    skeletons = ndx_pose.Skeletons(skeletons=skeleton_list)
    # Group all PoseEstimationSeries into a PoseEstimation object
    description = (
        f"Estimated positions of {', '.join(keypoints)} for "
        f"{individual} using {ds.source_software}."
    )
    pose_estimation = ndx_pose.PoseEstimation(
        pose_estimation_series=pose_estimation_series,
        skeleton=skeleton_list[-1],
        **config._resolve_pose_estimation_kwargs(
            individual,
            from_multi_individual,
            {
                "description": description,
                "source_software": ds.source_software,
            },
        ),
    )
    return pose_estimation, skeletons


def _write_processing_module(
    nwb_file: pynwb.file.NWBFile,
    processing_module_kwargs: dict[str, Any],
    pose_estimation: ndx_pose.PoseEstimation,
    skeletons: ndx_pose.Skeletons,
) -> None:
    """Write behaviour processing data to an NWB file.

    The PoseEstimation and Skeletons objects will be written to the
    specified ProcessingModule in the NWB file, formatted according to the
    ``ndx-pose`` NWB extension. If the module does not exist, it will be
    created. Existing objects in the NWB file will not be overwritten.

    Parameters
    ----------
    nwb_file : pynwb.file.NWBFile
        The NWBFile object to which the data will be added.
    processing_module_kwargs : dict[str, Any]
        Keyword arguments for the :class:`pynwb.base.ProcessingModule`
        in the NWB file. The ``name`` key will be used to determine
        the ProcessingModule to which the data will be added.
        If the ProcessingModule does not exist, it will be created
        with these keyword arguments.
    pose_estimation : ndx_pose.PoseEstimation
        PoseEstimation object containing the pose data for an individual.
    skeletons : ndx_pose.Skeletons
        Skeletons object containing the skeleton data for an individual.

    """

    def add_to_processing_module(obj, obj_name: str) -> None:
        # NOTE(review): a ValueError from ``add`` is taken to mean that an
        # object with the same name is already present - confirm against
        # pynwb; the existing object is kept and the new one skipped.
        try:
            processing_module.add(obj)
            logger.debug(f"Added {obj_name} object to NWB file.")
        except ValueError:
            logger.warning(f"{obj_name} object already exists. Skipping...")

    processing_module_name = processing_module_kwargs.get("name")
    processing_module = nwb_file.processing.get(processing_module_name)
    if processing_module is None:
        processing_module = nwb_file.create_processing_module(
            **processing_module_kwargs
        )
        logger.debug(
            f"Created {processing_module_name} processing module in NWB file."
        )
    else:
        logger.debug(
            f"Using existing {processing_module_name} processing module."
        )
    add_to_processing_module(skeletons, "Skeletons")
    add_to_processing_module(pose_estimation, "PoseEstimation")