Source code for movement.utils.vector

"""Utility functions for vector operations."""

import numpy as np
import xarray as xr

from movement.logging import log_error


[docs] def cart2pol(data: xr.DataArray) -> xr.DataArray: """Transform Cartesian coordinates to polar. Parameters ---------- data : xarray.DataArray The input data containing ``space`` as a dimension, with ``x`` and ``y`` in the dimension coordinate. Returns ------- xarray.DataArray An xarray DataArray containing the polar coordinates stored in the ``space_pol`` dimension, with ``rho`` and ``phi`` in the dimension coordinate. The angles ``phi`` returned are in radians, in the range ``[-pi, pi]``. """ _validate_dimension_coordinates(data, {"space": ["x", "y"]}) rho = xr.apply_ufunc( np.linalg.norm, data, input_core_dims=[["space"]], kwargs={"axis": -1}, ) phi = xr.apply_ufunc( np.arctan2, data.sel(space="y"), data.sel(space="x"), ) # Replace space dim with space_pol dims = list(data.dims) dims[dims.index("space")] = "space_pol" return xr.combine_nested( [ rho.assign_coords({"space_pol": "rho"}), phi.assign_coords({"space_pol": "phi"}), ], concat_dim="space_pol", ).transpose(*dims)
[docs] def pol2cart(data: xr.DataArray) -> xr.DataArray: """Transform polar coordinates to Cartesian. Parameters ---------- data : xarray.DataArray The input data containing ``space_pol`` as a dimension, with ``rho`` and ``phi`` in the dimension coordinate. Returns ------- xarray.DataArray An xarray DataArray containing the Cartesian coordinates stored in the ``space`` dimension, with ``x`` and ``y`` in the dimension coordinate. """ _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) rho = data.sel(space_pol="rho") phi = data.sel(space_pol="phi") x = rho * np.cos(phi) y = rho * np.sin(phi) # Replace space_pol dim with space dims = list(data.dims) dims[dims.index("space_pol")] = "space" return xr.combine_nested( [ x.assign_coords({"space": "x"}), y.assign_coords({"space": "y"}), ], concat_dim="space", ).transpose(*dims)
def _validate_dimension_coordinates( data: xr.DataArray, required_dim_coords: dict ) -> None: """Validate the input data array. Ensure that it contains the required dimensions and coordinates. Parameters ---------- data : xarray.DataArray The input data to validate. required_dim_coords : dict A dictionary of required dimensions and their corresponding coordinate values. Raises ------ ValueError If the input data does not contain the required dimension(s) and/or the required coordinate(s). """ missing_dims = [dim for dim in required_dim_coords if dim not in data.dims] error_message = "" if missing_dims: error_message += ( f"Input data must contain {missing_dims} as dimensions.\n" ) missing_coords = [] for dim, coords in required_dim_coords.items(): missing_coords = [ coord for coord in coords if coord not in data.coords.get(dim, []) ] if missing_coords: error_message += ( "Input data must contain " f"{missing_coords} in the '{dim}' coordinates." ) if error_message: raise log_error(ValueError, error_message)