# Source code for movement.io.load

"""Load data from various frameworks into ``movement``."""

from collections.abc import Callable
from functools import wraps
from pathlib import Path
from typing import (
    Concatenate,
    Literal,
    ParamSpec,
    Protocol,
    TypeAlias,
    TypeVar,
    cast,
)

import attrs
import pynwb
import xarray as xr

from movement.utils.logging import logger
from movement.validators.files import ValidFile

# Accepted first-argument types for loader functions: a filesystem path
# (as ``Path`` or ``str``) or an in-memory NWB file object.
TInputFile = TypeVar("TInputFile", Path, str, pynwb.file.NWBFile)
# Captures the remaining (non-file) parameters of a decorated loader,
# so wrappers can preserve the loader's full signature for type checkers.
P = ParamSpec("P")
# Names of the supported source software packages/formats.
SourceSoftware: TypeAlias = Literal[
    "DeepLabCut",
    "SLEAP",
    "LightningPose",
    "Anipose",
    "NWB",
    "VIA-tracks",
]


class LoaderProtocol(Protocol):
    """Protocol for loader functions to be registered via ``register_loader``.

    All loader functions registered via :func:`register_loader` must conform
    to this protocol. Loaders must accept a file path (str or Path) or a
    :class:`pynwb.file.NWBFile` object as their first argument and return an
    :class:`xarray.Dataset` containing pose tracks or bounding box tracks.
    Additional positional and keyword arguments are allowed.

    See Also
    --------
    register_loader : Decorator for registering loader functions.

    """

    def __call__(
        self, file: Path | str | pynwb.file.NWBFile, *args, **kwargs
    ) -> xr.Dataset:
        """Load data from a file.

        Parameters
        ----------
        file
            Path to the file or a :class:`pynwb.file.NWBFile` object.
        *args
            Additional positional arguments for the loader.
        **kwargs
            Additional keyword arguments for the loader.

        Returns
        -------
        xarray.Dataset
            The loaded dataset.

        """
        ...
# Global registry mapping source software names to their loader functions.
_LOADER_REGISTRY: dict[str, LoaderProtocol] = {}


def _get_validator_kwargs(
    validator_cls: type[ValidFile],
    *,
    loader_kwargs: dict,
) -> dict:
    """Extract the relevant kwargs for a given validator class.

    Only keys that correspond to ``__init__``-settable fields of the
    validator class are kept; everything else is silently dropped.
    """
    init_field_names = {
        field.name for field in attrs.fields(validator_cls) if field.init
    }
    return {
        name: value
        for name, value in loader_kwargs.items()
        if name in init_field_names
    }


def _build_suffix_map(
    validators_list: list[type[ValidFile]],
) -> dict[str, type[ValidFile]]:
    """Build a mapping of file suffixes to validator classes."""
    # Later validators override earlier ones for duplicate suffixes.
    return {
        suffix: validator_cls
        for validator_cls in validators_list
        for suffix in getattr(validator_cls, "suffixes", set())
    }


def _validate_file(
    file: TInputFile,
    suffix_map: dict[str, type[ValidFile]],
    source_software: SourceSoftware,
    loader_kwargs: dict | None = None,
) -> ValidFile:
    """Validate the input file using the appropriate validator.

    Parameters
    ----------
    file
        The file path or NWBFile object to validate.
    suffix_map
        Mapping of file suffixes to validator classes.
    source_software
        The source software name (for error messages).
    loader_kwargs
        Additional arguments from the loader function to pass to the
        validator.

    Returns
    -------
    ValidFile
        A validated file instance.

    Raises
    ------
    ValueError
        If the file format is not supported.

    """
    # An in-memory NWBFile has no path, so treat it as an ".nwb" file.
    file_suffix = (
        ".nwb"
        if isinstance(file, pynwb.file.NWBFile)
        else Path(file).suffix
    )
    validator_cls = suffix_map.get(file_suffix)
    if validator_cls is None:
        raise logger.error(
            ValueError(
                f"Unsupported format for '{source_software}': {file_suffix}."
            )
        )
    validator_kwargs = _get_validator_kwargs(
        validator_cls, loader_kwargs=loader_kwargs or {}
    )
    return validator_cls(
        file=file,
        **validator_kwargs,  # type: ignore[call-arg]
    )
def register_loader(
    source_software: SourceSoftware,
    *,
    file_validators: type[ValidFile] | list[type[ValidFile]] | None = None,
) -> Callable[
    [Callable[Concatenate[TInputFile, P], xr.Dataset]],
    Callable[Concatenate[TInputFile, P], xr.Dataset],
]:
    """Register a loader function for a given source software.

    The decorator also handles file validation using any provided
    file validator(s).

    Parameters
    ----------
    source_software
        The name of the source software.
    file_validators
        File validator(s) to validate the input file path and content.

    Returns
    -------
    Callable
        A decorator that registers the loader function.

    Notes
    -----
    If file validators are provided, the ``file`` argument passed to the
    decorated loader function will be an instance of the appropriate
    :class:`movement.validators.files.ValidFile` subclass, instead of the
    original file path or :class:`pynwb.file.NWBFile` object.

    Examples
    --------
    >>> from movement.io.load import register_loader
    >>> from movement.validators.files import (
    ...     ValidDeepLabCutH5,
    ...     ValidDeepLabCutCSV,
    ... )
    >>> @register_loader(
    ...     "DeepLabCut",
    ...     file_validators=[ValidDeepLabCutH5, ValidDeepLabCutCSV],
    ... )
    ... def from_dlc_file(file: str | Path, fps=None, **kwargs):
    ...     pass

    """
    # Normalise ``file_validators`` into a (possibly empty) list.
    if file_validators is None:
        validators_list: list[type[ValidFile]] = []
    elif isinstance(file_validators, list):
        validators_list = file_validators
    else:
        validators_list = [file_validators]
    # Map suffixes to validator classes.
    suffix_map = _build_suffix_map(validators_list)

    def decorator(
        loader_fn: Callable[Concatenate[TInputFile, P], xr.Dataset],
    ) -> Callable[Concatenate[TInputFile, P], xr.Dataset]:
        @wraps(loader_fn)
        def wrapper(file: TInputFile, *args, **kwargs) -> xr.Dataset:
            # Without validators, pass the raw file straight through.
            if not validators_list:
                return loader_fn(file, *args, **kwargs)
            valid_file = _validate_file(
                file, suffix_map, source_software, kwargs
            )
            return loader_fn(valid_file, *args, **kwargs)  # type: ignore[arg-type]

        # Register the loader in the global registry.
        _LOADER_REGISTRY[source_software] = cast("LoaderProtocol", wrapper)
        return wrapper

    return decorator
def load_dataset(
    file: Path | str | pynwb.file.NWBFile,
    source_software: SourceSoftware,
    fps: float | None = None,
    **kwargs,
) -> xr.Dataset:
    """Create a ``movement`` dataset from any supported third-party file.

    Parameters
    ----------
    file
        Path to the file containing predicted poses or tracked bounding
        boxes. If ``source_software`` is "NWB", this can also be a
        :class:`pynwb.file.NWBFile` object. The file format must be among
        those supported by the :mod:`movement.io.load_poses` or
        :mod:`movement.io.load_bboxes` modules. Based on the value of
        ``source_software``, the appropriate loading function will be
        called.
    source_software
        The source software of the file.
    fps
        The number of frames per second in the video. If None (default),
        the ``time`` coordinates will be in frame numbers. This argument
        is ignored when ``source_software`` is "NWB", as the frame rate
        will be directly read or estimated from metadata in the NWB file.
    **kwargs
        Additional keyword arguments to pass to the software-specific
        loading functions in modules listed under "See Also".

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose or bounding box tracks,
        confidence scores, and associated metadata.

    See Also
    --------
    movement.io.load_poses
    movement.io.load_bboxes

    Examples
    --------
    >>> from movement.io import load_dataset
    >>> ds = load_dataset(
    ...     "path/to/file.h5", source_software="DeepLabCut", fps=30
    ... )

    """
    loader = _LOADER_REGISTRY.get(source_software)
    if loader is None:
        raise logger.error(
            ValueError(f"Unsupported source software: {source_software}")
        )
    if source_software == "NWB":
        # NWB files carry their own frame-rate metadata, so fps is ignored.
        if fps is not None:
            logger.warning(
                "The fps argument is ignored when loading from an NWB file. "
                "The frame rate will be directly read or estimated from "
                "metadata in the file."
            )
        return loader(file, **kwargs)
    return loader(file, fps, **kwargs)
def load_multiview_dataset(
    file_dict: dict[str, Path | str],
    source_software: SourceSoftware,
    fps: float | None = None,
    **kwargs,
) -> xr.Dataset:
    """Load and merge data from multiple files representing different views.

    Parameters
    ----------
    file_dict
        A dict whose keys are the view names and values are the paths to
        load.
    source_software
        The source software of the file.
    fps
        The number of frames per second in the video. If None (default),
        the ``time`` coordinates will be in frame numbers. This argument
        is ignored when ``source_software`` is "NWB", as the frame rate
        will be directly read or estimated from metadata in the NWB file.
    **kwargs
        Additional keyword arguments to pass to the software-specific
        loading functions in modules listed under "See Also".

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing data concatenated along a new
        ``view`` dimension.

    Notes
    -----
    The attributes of the resulting dataset will be taken from the first
    dataset specified in ``file_dict``. This is the default behaviour of
    :func:`xarray.concat` used under the hood.

    See Also
    --------
    movement.io.load_poses
    movement.io.load_bboxes

    """
    # Iterating a dict yields its keys, i.e. the view names, in order.
    view_coord = xr.DataArray(list(file_dict), dims="view")
    per_view_datasets = [
        load_dataset(path, source_software=source_software, fps=fps, **kwargs)
        for path in file_dict.values()
    ]
    return xr.concat(per_view_datasets, dim=view_coord)