# Source code for movement.sample_data

"""Module for fetching and loading sample datasets.

This module provides functions for fetching and loading sample data used in
tests, examples, and tutorials. The data are stored in a remote repository
on GIN and are downloaded to the user's local machine the first time they
are used.
"""

import logging
from pathlib import Path

import pooch
import xarray
import yaml
from requests.exceptions import RequestException

from movement.io import load_poses
from movement.logging import log_error, log_warning

# Module-level logger, named after this module per logging convention
logger = logging.getLogger(__name__)

# URL to the remote data repository on GIN
# noinspection PyInterpreter
DATA_URL = (
    "https://gin.g-node.org/neuroinformatics/movement-test-data/raw/master"
)

# Save data in ~/.movement/data (expanded to the user's home directory)
DATA_DIR = Path("~", ".movement", "data").expanduser()
# Create the folder if it doesn't exist
DATA_DIR.mkdir(parents=True, exist_ok=True)

# File name for the .yaml file in DATA_URL containing dataset metadata
METADATA_FILE = "metadata.yaml"


def _download_metadata_file(file_name: str, data_dir: Path = DATA_DIR) -> Path:
    """Download the metadata yaml file.

    This function downloads the yaml file containing sample metadata from
    the ``movement`` data repository and saves it in the specified directory
    with a temporary filename - temp_{file_name} - to avoid overwriting any
    existing files.

    Parameters
    ----------
    file_name : str
        Name of the metadata file to fetch.
    data_dir : pathlib.Path, optional
        Directory to store the metadata file in. Defaults to the constant
        ``DATA_DIR``. Can be overridden for testing purposes.

    Returns
    -------
    path : pathlib.Path
        Path to the downloaded file.

    """
    # Save under a "temp_" prefixed name so an existing metadata file in
    # data_dir is never clobbered by a partial/failed download.
    downloaded_path = pooch.retrieve(
        url=f"{DATA_URL}/{file_name}",
        known_hash=None,
        path=data_dir,
        fname=f"temp_{file_name}",
        progressbar=False,
    )
    logger.debug(
        f"Successfully downloaded sample metadata file {file_name} "
        f"from {DATA_URL} to {data_dir}"
    )
    return Path(downloaded_path)


def _fetch_metadata(
    file_name: str, data_dir: Path = DATA_DIR
) -> dict[str, dict]:
    """Download the metadata yaml file and load it as a dictionary.

    Parameters
    ----------
    file_name : str
        Name of the metadata file to fetch.
    data_dir : pathlib.Path, optional
        Directory to store the metadata file in. Defaults to
        the constant ``DATA_DIR``. Can be overridden for testing purposes.

    Returns
    -------
    dict
        A dictionary containing metadata for each sample dataset, with the
        dataset name (pose file name) as the key.

    """
    target_path = Path(data_dir / file_name)
    failed_msg = "Failed to download the newest sample metadata file."

    try:
        # Attempt to fetch the latest metadata from the remote repository.
        fresh_copy = _download_metadata_file(file_name, data_dir)
    except RequestException as exc_info:
        # Download failed: fall back to a cached local copy if one exists,
        # otherwise there is nothing to load and we must raise.
        if not target_path.is_file():
            raise log_error(RequestException, failed_msg) from exc_info
        log_warning(
            f"{failed_msg} Will use the existing local version instead."
        )
    else:
        # Download succeeded: atomically replace any stale local copy.
        fresh_copy.replace(target_path)

    with open(target_path) as metadata_file:
        return yaml.safe_load(metadata_file)


def _generate_file_registry(metadata: dict[str, dict]) -> dict[str, str]:
    """Generate a file registry based on the contents of the metadata.

    This includes files containing poses, frames, or entire videos.

    Parameters
    ----------
    metadata : dict
        List of dictionaries containing metadata for each sample dataset.

    Returns
    -------
    dict
        Dictionary mapping file paths to their SHA-256 checksums.

    """
    file_registry = {}
    for ds, val in metadata.items():
        file_registry[f"poses/{ds}"] = val["sha256sum"]
        for key in ["video", "frame"]:
            file_name = val[key]["file_name"]
            if file_name:
                file_registry[f"{key}s/{file_name}"] = val[key]["sha256sum"]
    return file_registry


# Create a download manager for the pose data.
# NOTE: this runs at import time and may hit the network (via
# _fetch_metadata) the first time the module is imported.
metadata = _fetch_metadata(METADATA_FILE, DATA_DIR)
file_registry = _generate_file_registry(metadata)
SAMPLE_DATA = pooch.create(
    path=DATA_DIR,
    base_url=f"{DATA_URL}/",
    retry_if_failed=0,
    registry=file_registry,
)


def list_datasets() -> list[str]:
    """Find available sample datasets.

    Returns
    -------
    filenames : list of str
        List of filenames for available pose data.

    """
    # The metadata dict is keyed by pose file name, so its keys are
    # exactly the available dataset filenames.
    return list(metadata.keys())
def fetch_dataset_paths(filename: str) -> dict:
    """Get paths to sample pose data and any associated frames or videos.

    The data are downloaded from the ``movement`` data repository to the
    user's local machine upon first use and are stored in a local cache
    directory. The function returns the paths to the downloaded files.

    Parameters
    ----------
    filename : str
        Name of the pose file to fetch.

    Returns
    -------
    paths : dict
        Dictionary mapping file types to their respective paths. The possible
        file types are: "poses", "frame", "video". If "frame" or "video" are
        not available, the corresponding value is None.

    Raises
    ------
    ValueError
        If ``filename`` is not among the datasets listed in the registry.

    Examples
    --------
    >>> from movement.sample_data import fetch_dataset_paths
    >>> paths = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")
    >>> poses_path = paths["poses"]
    >>> frame_path = paths["frame"]
    >>> video_path = paths["video"]

    See Also
    --------
    fetch_dataset

    """
    available_pose_files = list_datasets()
    if filename not in available_pose_files:
        raise log_error(
            ValueError,
            f"File '{filename}' is not in the registry. "
            f"Valid filenames are: {available_pose_files}",
        )

    frame_file_name = metadata[filename]["frame"]["file_name"]
    video_file_name = metadata[filename]["video"]["file_name"]
    # Pose data always exists; frame/video are optional and may be None
    return {
        "poses": Path(
            SAMPLE_DATA.fetch(f"poses/{filename}", progressbar=True)
        ),
        "frame": None
        if not frame_file_name
        else Path(
            SAMPLE_DATA.fetch(f"frames/{frame_file_name}", progressbar=True)
        ),
        "video": None
        if not video_file_name
        else Path(
            SAMPLE_DATA.fetch(f"videos/{video_file_name}", progressbar=True)
        ),
    }
def fetch_dataset(
    filename: str,
) -> xarray.Dataset:
    """Load a sample dataset containing pose data.

    The data are downloaded from the ``movement`` data repository to the
    user's local machine upon first use and are stored in a local cache
    directory. This function returns the pose data as an xarray Dataset.
    If there are any associated frames or videos, these files are also
    downloaded and the paths are stored as dataset attributes.

    Parameters
    ----------
    filename : str
        Name of the file to fetch.

    Returns
    -------
    ds : xarray.Dataset
        Pose data contained in the fetched sample file.

    Examples
    --------
    >>> from movement.sample_data import fetch_dataset
    >>> ds = fetch_dataset("DLC_single-mouse_EPM.predictions.h5")
    >>> frame_path = ds.frame_path
    >>> video_path = ds.video_path

    See Also
    --------
    fetch_dataset_paths

    """
    file_paths = fetch_dataset_paths(filename)

    ds = load_poses.from_file(
        file_paths["poses"],
        source_software=metadata[filename]["source_software"],
        fps=metadata[filename]["fps"],
    )
    # Attach optional media paths (None if the dataset has no frame/video)
    ds.attrs["frame_path"] = file_paths["frame"]
    ds.attrs["video_path"] = file_paths["video"]
    return ds