.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/advanced/broadcasting_your_own_methods.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_advanced_broadcasting_your_own_methods.py>`
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_advanced_broadcasting_your_own_methods.py:

Broadcast functions across multi-dimensional data
=====================================================

Use the ``make_broadcastable`` decorator to efficiently apply functions
across any data dimension.

.. GENERATED FROM PYTHON SOURCE LINES 9-19

Summary
-------
The ``make_broadcastable`` decorator is particularly useful when you need to
apply the same operation to multiple individuals or time points without
writing complex loops.

The example walks through a practical case study of detecting when animals
enter a specific region of interest, showing how to convert a simple
point-in-rectangle check into a function that works on a data array with many
time-varying point trajectories.

.. GENERATED FROM PYTHON SOURCE LINES 21-28

Imports
-------
We will need ``numpy`` and ``xarray`` to make our custom data for this example,
and ``matplotlib`` to show what it contains.
We will be using the :mod:`movement.utils.broadcasting` module to
turn our one-dimensional functions into functions that work across
entire ``DataArray`` objects.

.. GENERATED FROM PYTHON SOURCE LINES 30-42

.. code-block:: Python

    # For interactive plots: install ipympl with `pip install ipympl` and uncomment
    # the following line in your notebook
    # %matplotlib widget
    import matplotlib.pyplot as plt
    import numpy as np
    import xarray as xr

    from movement import sample_data
    from movement.plots import plot_centroid_trajectory
    from movement.utils.broadcasting import make_broadcastable

.. GENERATED FROM PYTHON SOURCE LINES 43-48

Load Sample Dataset
-------------------
First, we load the ``SLEAP_three-mice_Aeon_proofread`` example dataset.
For the rest of this example we'll only need the ``position`` data array, so
we store it in a separate variable.

.. GENERATED FROM PYTHON SOURCE LINES 48-51

.. code-block:: Python

    ds = sample_data.fetch_dataset("SLEAP_three-mice_Aeon_proofread.analysis.h5")
    positions: xr.DataArray = ds.position

.. GENERATED FROM PYTHON SOURCE LINES 52-56

The individuals in this dataset follow very similar, arc-like trajectories.
To help emphasise what we are doing in this example, we will offset the paths
of two of the individuals by a small amount so that the trajectories are more
distinct.

.. GENERATED FROM PYTHON SOURCE LINES 56-60

.. code-block:: Python

    positions.loc[:, "y", :, "AEON3B_TP1"] -= 100.0
    positions.loc[:, "y", :, "AEON3B_TP2"] += 100.0

.. GENERATED FROM PYTHON SOURCE LINES 61-84

.. code-block:: Python

    fig, ax = plt.subplots(1, 1)
    for mouse_name, col in zip(
        positions.individuals.values, ["r", "g", "b"], strict=False
    ):
        plot_centroid_trajectory(
            positions,
            individual=mouse_name,
            keypoints="centroid",
            ax=ax,
            linestyle="-",
            marker=".",
            s=2,
            linewidth=0.5,
            c=col,
            label=mouse_name,
        )
    ax.invert_yaxis()
    ax.set_title("Trajectories")
    ax.set_xlabel("x (pixels)")
    ax.set_ylabel("y (pixels)")
    ax.legend()

.. image-sg:: /examples/advanced/images/sphx_glr_broadcasting_your_own_methods_001.png
   :alt: Trajectories
   :srcset: /examples/advanced/images/sphx_glr_broadcasting_your_own_methods_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 85-96

Motivation
----------
Suppose that, during our experiment, we have a region of the enclosure that
has a slightly wet floor, making it slippery. The individuals must cross this
region in order to reach some kind of reward on the other side of the
enclosure.
We know that the "slippery region" of our enclosure is approximately
rectangular in shape, and has its opposite corners at (400, 0) and
(600, 2000), where the coordinates are given in pixels.
We could then write a function that determines if a given (x, y) position is
inside this "slippery region".

.. GENERATED FROM PYTHON SOURCE LINES 96-117

.. code-block:: Python

    def in_slippery_region(xy_position) -> bool:
        """Return True if xy_position is in the slippery region.

        Return False otherwise.
        xy_position has 2 elements, the (x, y) coordinates respectively.
        """
        # The slippery region is a rectangle with the following bounds
        x_min, y_min = 400.0, 0.0
        x_max, y_max = 600.0, 2000.0
        # Both bounds are inclusive in x and in y
        is_within_bounds_x = x_min <= xy_position[0] <= x_max
        is_within_bounds_y = y_min <= xy_position[1] <= y_max
        return is_within_bounds_x and is_within_bounds_y


    # We can check our function with a few sample points
    for point in [(0, 100), (450, 700), (550, 1500), (601, 500)]:
        print(f"{point} is in slippery region: {in_slippery_region(point)}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (0, 100) is in slippery region: False
    (450, 700) is in slippery region: True
    (550, 1500) is in slippery region: True
    (601, 500) is in slippery region: False

.. GENERATED FROM PYTHON SOURCE LINES 118-123

Determine if each position was slippery
---------------------------------------
Given our data, we could extract whether each position (for each time-point,
and each individual) was inside the slippery region by looping over the
values.

.. GENERATED FROM PYTHON SOURCE LINES 123-165

.. code-block:: Python

    # We save one result per time-point, per keypoint, per individual
    in_slippery = np.zeros(
        shape=(
            len(positions["time"]),
            len(positions["keypoints"]),
            len(positions["individuals"]),
        ),
        dtype=bool,
    )

    # Feel free to uncomment the print statements
    # (line-by-line progress through the loop),
    # if you are running this code on your own machine.
    for time_index, time in enumerate(positions["time"].values):
        # print(f"At time {time}:")
        for keypoint_index, keypoint in enumerate(positions["keypoints"].values):
            # print(f"\tAt keypoint {keypoint}")
            for individual_index, individual in enumerate(
                positions["individuals"].values
            ):
                xy_point = positions.sel(
                    time=time,
                    keypoints=keypoint,
                    individuals=individual,
                )
                was_in_slippery = in_slippery_region(xy_point)
                was_in_slippery_text = (
                    "was in slippery region"
                    if was_in_slippery
                    else "was not in slippery region"
                )
                # print(
                #     "\t\tIndividual "
                #     f"{positions['individuals'].values[individual_index]} "
                #     f"{was_in_slippery_text}"
                # )
                # Save our result to our large array
                in_slippery[time_index, keypoint_index, individual_index] = (
                    was_in_slippery
                )

.. GENERATED FROM PYTHON SOURCE LINES 166-168

We could then build a new ``DataArray`` to store our results, so that we can
access the results in the same way that we did our original data.

.. GENERATED FROM PYTHON SOURCE LINES 168-184

.. code-block:: Python

    was_in_slippery_region = xr.DataArray(
        in_slippery,
        dims=["time", "keypoints", "individuals"],
        coords={
            "time": positions["time"],
            "keypoints": positions["keypoints"],
            "individuals": positions["individuals"],
        },
    )

    print(
        "Boolean DataArray indicating if at a given time, "
        "a given individual was inside the slippery region:"
    )
    was_in_slippery_region

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Boolean DataArray indicating if at a given time, a given individual was inside the slippery region:

.. raw:: html
<xarray.DataArray (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
    False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'
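
The nested loop above hard-codes three loop levels, one per non-spatial dimension. As a rough illustration only, the same bookkeeping can be written without fixing the number of dimensions by using plain NumPy and ``np.ndindex``. The array ``data`` below is a small random stand-in for ``positions`` (an assumption, not the sample dataset), ``in_region`` is a hypothetical copy of the rectangle check, and the coordinate labels that a ``DataArray`` carries are lost in this sketch.

```python
import numpy as np


def in_region(xy, x_min=400.0, x_max=600.0, y_min=0.0, y_max=2000.0):
    """Point-in-rectangle check for a single (x, y) pair."""
    return bool(x_min <= xy[0] <= x_max and y_min <= xy[1] <= y_max)


# Hypothetical stand-in for ``positions`` with dimensions
# (time, space, keypoints, individuals); values are random, not real data.
rng = np.random.default_rng(0)
data = rng.uniform(0.0, 1000.0, size=(5, 2, 1, 3))
space_axis = 1

# Move the space axis to the end so that every trailing slice is one (x, y) pair
moved = np.moveaxis(data, space_axis, -1)

# One boolean result per combination of the remaining dimensions
result = np.zeros(moved.shape[:-1], dtype=bool)
for index in np.ndindex(result.shape):
    result[index] = in_region(moved[index])

print(result.shape)
```

This removes the dependence on the number of dimensions, but the result is still a bare array that has to be wrapped in a new ``DataArray`` by hand.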


.. GENERATED FROM PYTHON SOURCE LINES 185-187

We can now find the first and last times at which an individual was inside the
slippery region by examining this ``DataArray``.

.. GENERATED FROM PYTHON SOURCE LINES 187-198

.. code-block:: Python

    i_id = "AEON3B_NTP"
    individual_0_centroid = was_in_slippery_region.sel(
        individuals=i_id, keypoints="centroid"
    )
    first_entry = individual_0_centroid["time"][individual_0_centroid].values[0]
    last_exit = individual_0_centroid["time"][individual_0_centroid].values[-1]
    print(
        f"{i_id} first entered the slippery region at "
        f"{first_entry} and last exited at {last_exit}"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    AEON3B_NTP first entered the slippery region at 2.1 and last exited at 8.64

.. GENERATED FROM PYTHON SOURCE LINES 199-212

Data Generalisation Issues
--------------------------
The shape of the resulting ``DataArray`` is the same as our original
``DataArray``, but without the ``"space"`` dimension.
Indeed, we have essentially collapsed the ``"space"`` dimension, since our
``in_slippery_region`` function takes in a 1D data slice (the x, y positions
of a single individual's centroid at a given point in time) and returns a
scalar value (True/False).
However, having to construct a new ``DataArray`` after running our function
over all space slices in our ``DataArray`` is not scalable: our ``for`` loop
approach relied on knowing how many dimensions our data had (and the size of
those dimensions). We have no guarantee that the next ``DataArray`` that comes
in will have the same structure.

.. GENERATED FROM PYTHON SOURCE LINES 214-223

Making our Function Broadcastable
---------------------------------
To combat this problem, we can make the observation that given any
``DataArray``, we always want to broadcast our ``in_slippery_region`` function
along the ``"space"`` dimension.
By "broadcast", we mean that we always want to run our function for each
1D-slice in the ``"space"`` dimension, since these are the (x, y) coordinates.
As such, we can decorate our function with the ``make_broadcastable``
decorator:

.. GENERATED FROM PYTHON SOURCE LINES 223-230

.. code-block:: Python

    @make_broadcastable()
    def in_slippery_region_broadcastable(xy_position) -> bool:
        return in_slippery_region(xy_position=xy_position)

.. GENERATED FROM PYTHON SOURCE LINES 231-235

Note that when writing your own methods, there is no need to have both
``in_slippery_region`` and ``in_slippery_region_broadcastable``; simply apply
the ``make_broadcastable`` decorator to ``in_slippery_region`` directly.
We've made two separate functions here to illustrate what's going on.

.. GENERATED FROM PYTHON SOURCE LINES 237-239

``in_slippery_region_broadcastable`` is usable in exactly the same ways as
``in_slippery_region`` was:

.. GENERATED FROM PYTHON SOURCE LINES 239-247

.. code-block:: Python

    for point in [(0, 100), (450, 700), (550, 1500), (601, 500)]:
        print(
            f"{point} is in slippery region: "
            f"{in_slippery_region_broadcastable(point)}"
        )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (0, 100) is in slippery region: False
    (450, 700) is in slippery region: True
    (550, 1500) is in slippery region: True
    (601, 500) is in slippery region: False

.. GENERATED FROM PYTHON SOURCE LINES 248-253

However, ``in_slippery_region_broadcastable`` also accepts a ``DataArray`` as
the first (``xy_position``) argument, and an extra keyword argument
``broadcast_dimension``. These arguments let us broadcast across the given
dimension of the input ``DataArray``, treating each 1D-slice as a separate
input to ``in_slippery_region``.

.. GENERATED FROM PYTHON SOURCE LINES 253-262

.. code-block:: Python

    in_slippery_region_broadcasting = in_slippery_region_broadcastable(
        positions,  # Now a DataArray input
        broadcast_dimension="space",
    )

    print("DataArray output using broadcasting: ")
    in_slippery_region_broadcasting

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    DataArray output using broadcasting:

.. raw:: html
<xarray.DataArray 'position' (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
    False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'
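
For readers more familiar with NumPy, ``np.apply_along_axis`` is a close analogue of broadcasting a 1D function along one dimension. The sketch below is not how ``make_broadcastable`` is implemented; it only illustrates the idea on a small hand-written array with assumed dimensions ``(time, space, individuals)``, using a hypothetical ``in_region`` copy of the rectangle check.

```python
import numpy as np


def in_region(xy):
    """Return True if the (x, y) pair lies inside the slippery rectangle."""
    x_min, y_min = 400.0, 0.0
    x_max, y_max = 600.0, 2000.0
    return bool(x_min <= xy[0] <= x_max and y_min <= xy[1] <= y_max)


# Hypothetical data: 2 time points, 2 spatial coordinates, 3 individuals
xy = np.array(
    [
        [[0.0, 450.0, 601.0], [100.0, 700.0, 500.0]],  # time 0: x row, y row
        [[550.0, 400.0, 600.0], [1500.0, 2500.0, 0.0]],  # time 1: x row, y row
    ]
)

# Run in_region on every 1D slice taken along the space axis (axis=1);
# that axis is collapsed in the output, just like "space" above.
inside = np.apply_along_axis(in_region, axis=1, arr=xy)

print(inside.shape)  # (2, 3)
print(inside.tolist())  # [[False, True, False], [True, False, True]]
```

Unlike the decorated function, this plain-array version returns no coordinate labels, which is the information the ``DataArray`` output above retains.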


.. GENERATED FROM PYTHON SOURCE LINES 263-268

Calling ``in_slippery_region_broadcastable`` in this way gives us a
``DataArray`` output, and one that retains any information that was in our
original ``DataArray`` to boot! The result is exactly the same as what we got
from using our ``for`` loop, and then adding the extra information to the
result.

.. GENERATED FROM PYTHON SOURCE LINES 268-274

.. code-block:: Python

    # Throws an AssertionError if the two inputs are not the same
    xr.testing.assert_equal(
        was_in_slippery_region, in_slippery_region_broadcasting
    )

.. GENERATED FROM PYTHON SOURCE LINES 275-278

But importantly, ``in_slippery_region_broadcastable`` also works on
``DataArray`` objects with different dimensions.
For example, we could have pre-selected one of our individuals beforehand.

.. GENERATED FROM PYTHON SOURCE LINES 278-292

.. code-block:: Python

    i_id = "AEON3B_NTP"
    individual_0 = positions.sel(individuals=i_id)

    individual_0_in_slippery_region = in_slippery_region_broadcastable(
        individual_0,
        broadcast_dimension="space",
    )

    print(
        "We get a 2D DataArray output from our 3D input, "
        "again with the 'space' dimension that we broadcast along collapsed:"
    )
    individual_0_in_slippery_region

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    We get a 2D DataArray output from our 3D input, again with the 'space' dimension that we broadcast along collapsed:

.. raw:: html
<xarray.DataArray 'position' (time: 601, keypoints: 1)> Size: 601B
    False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
        individuals  <U10 40B 'AEON3B_NTP'


.. GENERATED FROM PYTHON SOURCE LINES 293-305

Additional Function Arguments
-----------------------------
So far our ``in_slippery_region`` method only takes a single argument,
the ``xy_position`` itself. However, in follow-up experiments, we might move
the slippery region in the enclosure, and so adapt our existing function to
make it more general.
It will now allow someone to input a custom rectangular region, by specifying
the minimum and maximum ``(x, y)`` coordinates of the rectangle, rather than
relying on fixed values inside the function.
The default region will be the rectangle from our first experiment, and we
still want to be able to broadcast this function.
And so we write a more general function, as below.

.. GENERATED FROM PYTHON SOURCE LINES 305-328

.. code-block:: Python

    @make_broadcastable()
    def in_slippery_region_general(
        xy_position, xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)
    ) -> bool:
        """Return True if xy_position is in the slippery region.

        Return False otherwise.
        xy_position has 2 elements, the (x, y) coordinates respectively.
        """
        x_min, y_min = xy_min
        x_max, y_max = xy_max
        is_within_bounds_x = x_min <= xy_position[0] <= x_max
        is_within_bounds_y = y_min <= xy_position[1] <= y_max
        return is_within_bounds_x and is_within_bounds_y


    # (0.5, 0.5) is in the unit square whose bottom left corner is at the origin
    print(in_slippery_region_general((0.5, 0.5), (0.0, 0.0), (1.0, 1.0)))
    # But (0.5, 0.5) is not in a unit square whose bottom left corner is at (1, 1)
    print(in_slippery_region_general((0.5, 0.5), (1.0, 1.0), (2.0, 2.0)))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    True
    False

.. GENERATED FROM PYTHON SOURCE LINES 329-332

``make_broadcastable`` retains any additional arguments of the function we
define. However, the ``xy_position`` argument must be the first argument in
the function's ``def`` statement.

.. GENERATED FROM PYTHON SOURCE LINES 332-341

.. code-block:: Python

    # Default arguments should give us the same results as before
    xr.testing.assert_equal(
        was_in_slippery_region, in_slippery_region_general(positions)
    )
    # But we can also provide the optional arguments in the same way as with the
    # un-decorated function.
    in_slippery_region_general(positions, xy_min=(100, 0), xy_max=(400, 1000))

.. raw:: html
<xarray.DataArray 'position' (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
    False False True False False True False ... True True False True True False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'
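
To make the argument-forwarding behaviour concrete, here is a minimal sketch of a decorator that passes extra positional and keyword arguments through to the wrapped 1D function. It is a hypothetical illustration built on ``np.apply_along_axis`` and plain arrays, not the implementation of ``make_broadcastable``; the names ``broadcast_along`` and ``in_rectangle`` are invented for this sketch. It also shows why the data has to be the first argument: the wrapper treats its first positional argument as the array to slice.

```python
import functools

import numpy as np


def broadcast_along(axis):
    """Hypothetical decorator factory; not movement's implementation."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(data, *args, **kwargs):
            data = np.asarray(data, dtype=float)
            if data.ndim == 1:
                # A single 1D input is handed straight to the function
                return func(data, *args, **kwargs)
            # Extra arguments are forwarded unchanged to every 1D slice
            return np.apply_along_axis(func, axis, data, *args, **kwargs)

        return wrapper

    return decorator


@broadcast_along(axis=-1)
def in_rectangle(xy, xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)):
    x_min, y_min = xy_min
    x_max, y_max = xy_max
    return bool(x_min <= xy[0] <= x_max and y_min <= xy[1] <= y_max)


# Hypothetical points with dimensions (point, space)
points = np.array([[0.5, 0.5], [1.5, 1.5], [450.0, 700.0]])

default_region = in_rectangle(points)
unit_square = in_rectangle(points, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0))
single_point = in_rectangle((0.5, 0.5), (1.0, 1.0), (2.0, 2.0))

print(default_region.tolist())  # [False, False, True]
print(unit_square.tolist())  # [True, False, False]
print(single_point)  # False
```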


.. GENERATED FROM PYTHON SOURCE LINES 342-348

Only Broadcast Along Select Dimensions
--------------------------------------
The ``make_broadcastable`` decorator has some flexibility with its input
arguments, to help you avoid unintentional behaviour. You may have noticed,
for example, that there is nothing stopping someone who wants to use your
analysis code from trying to broadcast along the wrong dimension.

.. GENERATED FROM PYTHON SOURCE LINES 348-356

.. code-block:: Python

    silly_broadcast = in_slippery_region_broadcastable(
        positions, broadcast_dimension="time"
    )

    print("The output has collapsed the time dimension:")
    silly_broadcast

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    The output has collapsed the time dimension:

.. raw:: html
<xarray.DataArray 'position' (space: 2, keypoints: 1, individuals: 3)> Size: 6B
    False False False False False False
    Coordinates:
      * space        (space) <U1 8B 'x' 'y'
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'
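
The reason a call like this runs without complaint can be reproduced with plain NumPy. In the sketch below (random stand-in data with assumed dimensions ``(time, space, individuals)``, and a hypothetical ``in_region`` copy of the rectangle check), applying the function along the time axis simply reads the first two time samples of each slice as if they were x and y, and ignores the rest.

```python
import numpy as np


def in_region(xy):
    """Rectangle check that only ever looks at elements 0 and 1 of its input."""
    return bool(400.0 <= xy[0] <= 600.0 and 0.0 <= xy[1] <= 2000.0)


# Hypothetical data: 50 time points, 2 spatial coordinates, 3 individuals
rng = np.random.default_rng(seed=1)
data = rng.uniform(0.0, 1000.0, size=(50, 2, 3))

# "Broadcasting" along time (axis 0) is a valid operation on 1D slices...
along_time = np.apply_along_axis(in_region, 0, data)
print(along_time.shape)  # (2, 3): time is collapsed, space is kept

# ...but each slice is a time series, so the check compares time sample 0
# against the x bounds and time sample 1 against the y bounds.
first, second = data[0], data[1]
manual = (
    (400.0 <= first) & (first <= 600.0) & (0.0 <= second) & (second <= 2000.0)
)
print(np.array_equal(along_time, manual))  # True
```

The output is well-formed but has no meaning as a position test, which is the kind of mistake a restriction on the broadcast dimension is meant to rule out.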


.. GENERATED FROM PYTHON SOURCE LINES 357-367

There is no error thrown because functionally, this is a valid operation.
The time slices of our data were 1D, so we can run ``in_slippery_region`` on
them. But each slice isn't a position: it holds the values of a single spatial
coordinate (e.g. x) of one keypoint on one individual, at every time point!
So from an analysis standpoint, doing this doesn't make sense and isn't how we
intend our function to be used.

We can pass the ``only_broadcastable_along`` keyword argument to
``make_broadcastable`` to prevent these kinds of mistakes, and make our
intentions clearer.

.. GENERATED FROM PYTHON SOURCE LINES 367-374

.. code-block:: Python

    @make_broadcastable(only_broadcastable_along="space")
    def in_slippery_region_space_only(xy_position):
        return in_slippery_region(xy_position)

.. GENERATED FROM PYTHON SOURCE LINES 375-377

Now, ``in_slippery_region_space_only`` no longer takes the
``broadcast_dimension`` argument.

.. GENERATED FROM PYTHON SOURCE LINES 377-386

.. code-block:: Python

    try:
        in_slippery_region_space_only(
            positions,
            broadcast_dimension="time",
        )
    except TypeError as e:
        print(f"Got a TypeError when trying to run, here's the message:\n{e}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Got a TypeError when trying to run, here's the message:
    __main__.in_slippery_region_space_only() got multiple values for keyword argument 'broadcast_dimension'

.. GENERATED FROM PYTHON SOURCE LINES 387-394

The error we get seems to be telling us that we've tried to set the value
of ``broadcast_dimension`` twice. Specifying
``only_broadcastable_along="space"`` forces ``broadcast_dimension`` to be
set to ``"space"``, so trying to set it again (even to the same value)
results in an error.
However, ``in_slippery_region_space_only`` knows to only use the ``"space"``
dimension of the input by default.

.. GENERATED FROM PYTHON SOURCE LINES 394-401

.. code-block:: Python

    was_in_view_space_only = in_slippery_region_space_only(positions)

    xr.testing.assert_equal(
        in_slippery_region_broadcasting, was_in_view_space_only
    )

.. GENERATED FROM PYTHON SOURCE LINES 402-406

It is worth noting that there is a "helper" decorator,
``space_broadcastable``, that essentially does the same thing as
``make_broadcastable(only_broadcastable_along="space")``.
You can use this decorator for your own convenience.

.. GENERATED FROM PYTHON SOURCE LINES 408-412

Extending to Class Methods
--------------------------
``make_broadcastable`` can also be applied to class methods, though it needs
to be told that you are doing so via the ``is_classmethod`` parameter.

.. GENERATED FROM PYTHON SOURCE LINES 412-442

.. code-block:: Python

    class Rectangle:
        """Represents an observing camera in the experiment."""

        xy_min: tuple[float, float]
        xy_max: tuple[float, float]

        def __init__(self, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0)):
            """Create a new instance."""
            self.xy_min = tuple(xy_min)
            self.xy_max = tuple(xy_max)

        @make_broadcastable(is_classmethod=True, only_broadcastable_along="space")
        def is_inside(self, /, xy_position) -> bool:
            """Whether the position is inside the rectangle."""
            # For the sake of brevity, we won't redefine the entire method here,
            # and will just call our existing function.
            return in_slippery_region_general(
                xy_position, self.xy_min, self.xy_max
            )


    slippery_region = Rectangle(xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0))
    was_in_region_clsmethod = slippery_region.is_inside(positions)

    xr.testing.assert_equal(
        was_in_region_clsmethod, in_slippery_region_broadcasting
    )

.. GENERATED FROM PYTHON SOURCE LINES 443-446

The ``broadcastable_method`` decorator is provided as a helpful alias for
``make_broadcastable(is_classmethod=True)``, and otherwise works in the same
way (and accepts the same parameters).

.. GENERATED FROM PYTHON SOURCE LINES 446-480

.. code-block:: Python

    # Imported from the same module as make_broadcastable
    from movement.utils.broadcasting import broadcastable_method


    class RectangleAlternative:
        """Represents an observing camera in the experiment."""

        xy_min: tuple[float, float]
        xy_max: tuple[float, float]

        def __init__(self, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0)):
            """Create a new instance."""
            self.xy_min = tuple(xy_min)
            self.xy_max = tuple(xy_max)

        @broadcastable_method(only_broadcastable_along="space")
        def is_inside(self, /, xy_position) -> bool:
            """Whether the position is inside the rectangle."""
            # For the sake of brevity, we won't redefine the entire method here,
            # and will just call our existing function.
            return in_slippery_region_general(
                xy_position, self.xy_min, self.xy_max
            )


    slippery_region_alt = RectangleAlternative(
        xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)
    )
    # Call the method on the new instance, not on the original Rectangle
    was_in_region_clsmethod_alt = slippery_region_alt.is_inside(positions)

    xr.testing.assert_equal(
        was_in_region_clsmethod_alt, in_slippery_region_broadcasting
    )
    xr.testing.assert_equal(was_in_region_clsmethod_alt, was_in_region_clsmethod)

.. GENERATED FROM PYTHON SOURCE LINES 481-484

In fact, if you look at the Regions of Interest submodule, and in particular
the classes inside it, you'll notice that we use the ``broadcastable_method``
decorator ourselves in some of these methods!

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.109 seconds)

.. _sphx_glr_download_examples_advanced_broadcasting_your_own_methods.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/neuroinformatics-unit/movement/gh-pages?filepath=notebooks/examples/advanced/broadcasting_your_own_methods.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: broadcasting_your_own_methods.ipynb <broadcasting_your_own_methods.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: broadcasting_your_own_methods.py <broadcasting_your_own_methods.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: broadcasting_your_own_methods.zip <broadcasting_your_own_methods.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery