r"""Broadcasting operations across ``xarray.DataArray`` dimensions.

This module essentially provides an equivalent functionality to
:func:`numpy.apply_along_axis`, but for ``xarray.DataArray`` objects.
This functionality is provided as a decorator, so it can be applied to both
functions within the package and be available to users who would like to use it
in their analysis.

In essence: suppose that we have a function which takes a 1D-slice of a
``xarray.DataArray`` and returns either a scalar value, or another 1D array.
Typically, one would either have to call this function successively in a
``for`` loop, looping over all the 1D slices in a ``xarray.DataArray`` that
need to be examined, or re-write the function to be able to broadcast along the
necessary dimension of the data structure.

The :func:`make_broadcastable()\
<movement.utils.broadcasting.make_broadcastable>` decorator takes care of the
latter piece of work, allowing us to write functions that operate on 1D slices,
then apply this decorator to have them work across ``xarray.DataArray``
dimensions. The function

>>> def my_function(input_1d, *args, **kwargs):
...     # do something
...     return scalar_or_1d_output

which previously only worked with 1D-slices can be decorated

>>> @make_broadcastable()
... def my_function(input_1d, *args, **kwargs):
...     # do something
...     return scalar_or_1d_output

effectively changing its call signature to

>>> def my_function(data_array, *args, broadcast_dimension, **kwargs):
...     # do my_function, but do it to all the slices
...     # along the broadcast_dimension of data_array.
...     return data_array_output

which will perform the action of ``my_function`` along the
``broadcast_dimension`` given (defaulting to ``"space"``).
The ``*args`` and ``**kwargs`` retain their original interpretations from
``my_function`` too.
"""
from collections.abc import Callable
from functools import wraps
from typing import (
Concatenate,
Literal,
ParamSpec,
TypeAlias,
TypeVar,
overload,
)
import numpy as np
import xarray as xr
from numpy.typing import ArrayLike
# --- Type variables -------------------------------------------------------
# What a function returns when applied to one 1D slice: a scalar, or another
# 1D array-like.
ScalarOr1D = TypeVar("ScalarOr1D", float, int, bool, ArrayLike)
# The type of the instance (``self``) argument of decorated methods.
Self = TypeVar("Self")
# The remaining positional/keyword parameters of a decorated callable, which
# are forwarded unchanged by the decorator.
P = ParamSpec("P")

# --- Standalone functions: before and after decoration --------------------
Function1DTo1D: TypeAlias = Callable[Concatenate[ArrayLike, P], ScalarOr1D]
FunctionDaToDa: TypeAlias = Callable[
    Concatenate[xr.DataArray, P],
    xr.DataArray,
]

# --- Methods (``self`` first): before and after decoration ----------------
ClsMethod1DTo1D: TypeAlias = Callable[
    Concatenate[Self, ArrayLike, P],
    ScalarOr1D,
]
ClsMethodDaToDa: TypeAlias = Callable[
    Concatenate[Self, xr.DataArray, P],
    xr.DataArray,
]
def apply_along_da_axis(
    f: Callable[[ArrayLike], ScalarOr1D],
    data: xr.DataArray,
    dimension: str,
    new_dimension_name: str | None = None,
) -> xr.DataArray:
    """Apply a function ``f`` across ``dimension`` of ``data``.

    ``f`` should be callable as ``f(input_1d)`` where ``input_1d`` is a
    one-dimensional :data:`numpy.typing.ArrayLike` object. It should return
    either a scalar or one-dimensional :data:`numpy.typing.ArrayLike` object.

    Parameters
    ----------
    f : Callable
        Function that takes 1D inputs and returns either scalar or 1D outputs.
        This will be cast across the ``dimension`` of the ``data``.
    data : xarray.DataArray
        Values to be cast over.
    dimension : str
        Dimension of ``data`` to broadcast ``f`` across.
    new_dimension_name : str, optional
        If ``f`` returns non-scalar values, the dimension in the output that
        these values are returned along is given the name
        ``new_dimension_name``. Defaults to ``"result"``.

    Returns
    -------
    xarray.DataArray
        Result of broadcasting ``f`` along the ``dimension`` of ``data``.

        - If ``f`` returns a scalar or ``(1,)``-shaped output, the output has
          one fewer dimension than ``data``, with ``dimension`` being dropped.
          All other dimensions retain their names and sizes.
        - Otherwise (``f`` returns an ``(n,)``-shaped output with ``n != 1``,
          including ``n == 0``), all non-``dimension`` dimensions of ``data``
          retain their shapes. The ``dimension`` dimension itself is replaced
          with a new dimension, ``new_dimension_name``, containing the output
          of the application of ``f``.

    Notes
    -----
    Because ``dimension`` is passed to :func:`xarray.apply_ufunc` via
    ``exclude_dims``, any coordinates of ``data`` along ``dimension`` are not
    carried over to the output.

    """
    output: xr.DataArray = xr.apply_ufunc(
        # Promote scalar results to shape (1,) so every slice yields a 1D
        # core dimension, whatever ``f`` returns.
        lambda input_1d: np.atleast_1d(f(input_1d)),
        data,
        input_core_dims=[[dimension]],
        # The core dimension may change length between input and output
        # (e.g. collapse to a single value), which xarray only permits for
        # excluded dimensions.
        exclude_dims={dimension},
        output_core_dims=[[dimension]],
        vectorize=True,
    )
    if output.sizes[dimension] == 1:
        # Scalar (or length-1) results: drop the now-redundant dimension.
        return output.squeeze(dim=dimension)
    # Non-scalar output (including empty): rename the core dimension
    # according to request rather than squeezing, which would fail on a
    # length-0 dimension.
    return output.rename({dimension: new_dimension_name or "result"})
# Typing-only overloads: they narrow the decorator's signature depending on
# whether it targets a standalone function or a method. The runtime
# implementation follows them.


@overload
def make_broadcastable(
    *,
    # Default mirrors the implementation, so that calls which omit
    # ``is_classmethod`` (e.g. ``@make_broadcastable()``) match an overload.
    is_classmethod: Literal[False] = False,
    only_broadcastable_along: str | None = None,
    new_dimension_name: str | None = None,
) -> Callable[[Function1DTo1D[P, ScalarOr1D]], FunctionDaToDa[P]]: ...


@overload
def make_broadcastable(
    *,
    is_classmethod: Literal[True],
    only_broadcastable_along: str | None = None,
    new_dimension_name: str | None = None,
) -> Callable[
    [ClsMethod1DTo1D[Self, P, ScalarOr1D]], ClsMethodDaToDa[Self, P]
]: ...


@overload
def make_broadcastable(
    *,
    is_classmethod: bool,
    only_broadcastable_along: str | None = None,
    new_dimension_name: str | None = None,
) -> Callable[
    ..., FunctionDaToDa[P] | ClsMethodDaToDa[Self, P]
]: ...  # Fallback when `is_classmethod` is unknown at type-check time
def make_broadcastable(
    *,
    is_classmethod: bool = False,
    only_broadcastable_along: str | None = None,
    new_dimension_name: str | None = None,
):
    """Create a decorator that allows a function to be broadcast.

    Parameters
    ----------
    is_classmethod : bool
        Whether the target of the decoration is a class method which takes
        the ``self`` argument, or a standalone function that receives no
        implicit arguments. Defaults to ``False``.
    only_broadcastable_along : str, optional
        Whether the decorated function should only support broadcasting along
        this dimension. The returned function will not take the
        ``broadcast_dimension`` argument, and will use the dimension provided
        here as the value for this argument.
    new_dimension_name : str, optional
        Passed to :func:`apply_along_da_axis`.

    Returns
    -------
    Callable
        Decorator function that can be applied with the
        ``@make_broadcastable(...)`` syntax. See Notes for a description of
        the action of the returned decorator.

    Raises
    ------
    TypeError
        Raised by the decorated function when it is called without ``data``
        as a positional argument.

    Notes
    -----
    The returned decorator (the "``r_decorator``") extends a function that
    acts on a 1D sequence of values, allowing it to be broadcast along the
    axes of an input ``xarray.DataArray``.

    The ``r_decorator`` takes a single parameter, ``f``. ``f`` should be a
    ``Callable`` that acts on 1D inputs, that is to be converted into a
    broadcast-able function ``fr``, applying the action of ``f`` along an axis
    of an ``xarray.DataArray``. ``f`` should return either scalar or 1D
    outputs.

    If ``f`` is a class method, it should be callable as
    ``f(self, [x, y, ...], *args, **kwargs)``.
    Otherwise, ``f`` should be callable as
    ``f([x, y, ...], *args, **kwargs)``.

    The function ``fr`` returned by the ``r_decorator`` is callable with the
    signature
    ``fr([self,] data, *args, broadcast_dimension="space", **kwargs)``,
    where the ``self`` argument is present only if ``f`` was a class method,
    and ``data`` must be passed positionally.
    ``fr`` applies ``f`` along the ``broadcast_dimension`` of ``data``, which
    defaults to ``"space"``. When ``only_broadcastable_along`` is given, that
    dimension is always used instead, and any ``broadcast_dimension`` keyword
    is passed through to ``f`` unchanged.
    The ``*args`` and ``**kwargs`` match those passed to ``f``, and retain
    the same interpretations and effects on the result. If ``data`` provided to
    ``fr`` is not an ``xarray.DataArray``, it will fall back on the behaviour
    of ``f`` (and ignore the ``broadcast_dimension`` argument).

    See Also
    --------
    broadcastable_method : Convenience alias for ``is_classmethod = True``.
    space_broadcastable : Convenience alias for
        ``only_broadcastable_along = "space"``.

    Examples
    --------
    Make a standalone function broadcast along the ``"space"`` axis of an
    ``xarray.DataArray``.

    >>> @make_broadcastable(
    ...     is_classmethod=False, only_broadcastable_along="space"
    ... )
    ... def my_function(xyz_data, *args, **kwargs):
    ...     return sum(xyz_data)
    >>> # Call via the usual arguments, replacing the xyz_data argument with
    >>> # the xarray.DataArray to broadcast over
    >>> my_function(data_array, *args, **kwargs)

    Make a class method broadcast along any axis of an ``xarray.DataArray``.

    >>> from dataclasses import dataclass
    >>>
    >>> @dataclass
    ... class MyClass:
    ...     factor: float
    ...     offset: float
    ...
    ...     @make_broadcastable(is_classmethod=True)
    ...     def manipulate_values(self, xyz_values, *args, **kwargs):
    ...         return self.factor * sum(xyz_values) + self.offset
    >>> m = MyClass(factor=5.9, offset=1.0)
    >>> m.manipulate_values(
    ...     data_array, *args, broadcast_dimension="time", **kwargs
    ... )

    """
    # Number of leading positional arguments up to and including ``data``:
    # ``(self, data)`` for methods, ``(data,)`` for standalone functions.
    n_leading = 2 if is_classmethod else 1

    def decorator(f):
        @wraps(f)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> xr.DataArray:
            if len(args) < n_leading:
                raise TypeError(
                    "missing required positional argument 'data'; it must be "
                    "passed positionally so that it can be broadcast over"
                )
            # Resolve (and strip from ``kwargs``) the broadcast dimension
            # before ``f`` can ever see it, so that neither the broadcast
            # path nor the non-DataArray fallback forwards it to ``f``.
            broadcast_dimension = (
                only_broadcastable_along
                if only_broadcastable_along is not None
                else kwargs.pop("broadcast_dimension", "space")
            )

            data = args[n_leading - 1]
            if not isinstance(data, xr.DataArray):
                # Not a DataArray: behave exactly like the undecorated ``f``.
                return f(*args, **kwargs)

            # ``(self,)`` for methods, ``()`` for standalone functions.
            leading = args[: n_leading - 1]
            rest = args[n_leading:]

            def call_f(input_1d):
                return f(*leading, input_1d, *rest, **kwargs)

            return apply_along_da_axis(
                call_f,
                data,
                broadcast_dimension,
                new_dimension_name=new_dimension_name,
            )

        return wrapper

    return decorator
def space_broadcastable(
    *,
    is_classmethod: bool = False,
    new_dimension_name: str | None = None,
) -> Callable[..., FunctionDaToDa[P] | ClsMethodDaToDa[Self, P]]:
    """Broadcast a 1D function along the 'space' dimension.

    This is a convenience wrapper for
    ``make_broadcastable(only_broadcastable_along='space')``,
    and is primarily useful when we want to write a function that acts on
    coordinates, that can only be cast across the 'space' dimension of an
    ``xarray.DataArray``.

    Parameters
    ----------
    is_classmethod : bool
        Whether the target of the decoration is a class method which takes
        the ``self`` argument, or a standalone function. Defaults to
        ``False``.
    new_dimension_name : str, optional
        Passed to :func:`make_broadcastable`.

    Returns
    -------
    Callable
        Decorator which converts a 1D function ``f`` into a callable with
        signature ``([self,] data, *args, **kwargs)`` that applies ``f``
        along the ``"space"`` dimension of ``data``. ``*args`` and
        ``**kwargs`` match those passed to ``f``, and retain the same
        interpretations. Because the dimension is fixed to ``"space"``, the
        resulting callable does not take a ``broadcast_dimension`` argument.

    See Also
    --------
    make_broadcastable : The aliased decorator function.

    """
    return make_broadcastable(
        is_classmethod=is_classmethod,
        only_broadcastable_along="space",
        new_dimension_name=new_dimension_name,
    )
def broadcastable_method(
    only_broadcastable_along: str | None = None,
    new_dimension_name: str | None = None,
) -> Callable[
    [ClsMethod1DTo1D[Self, P, ScalarOr1D]], ClsMethodDaToDa[Self, P]
]:
    """Broadcast a class method along a ``xarray.DataArray`` dimension.

    This is a convenience wrapper for
    ``make_broadcastable(is_classmethod = True)``,
    for use when extending class methods that act on coordinates, that we wish
    to cast across the axes of an ``xarray.DataArray``.

    Parameters
    ----------
    only_broadcastable_along : str, optional
        If given, the decorated method only broadcasts along this dimension
        and does not take a ``broadcast_dimension`` argument. Passed to
        :func:`make_broadcastable`.
    new_dimension_name : str, optional
        Passed to :func:`make_broadcastable`.

    Returns
    -------
    Callable
        Decorator which converts a method ``f`` acting on 1D inputs into a
        method with signature
        ``(self, data, *args, broadcast_dimension="space", **kwargs)``
        that applies ``f`` along the ``broadcast_dimension`` of ``data``.
        ``*args`` and ``**kwargs`` match those passed to ``f``, and
        retain the same interpretations.

    See Also
    --------
    make_broadcastable : The aliased decorator function.

    """
    return make_broadcastable(
        is_classmethod=True,
        only_broadcastable_along=only_broadcastable_along,
        new_dimension_name=new_dimension_name,
    )