.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/broadcasting_your_own_methods.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_broadcasting_your_own_methods.py:


Extend your analysis methods along data dimensions
=====================================================

Learn how to use the ``make_broadcastable`` decorator to easily
cast functions across an entire ``xarray.DataArray``.

.. GENERATED FROM PYTHON SOURCE LINES 9-16

Imports
-------
We will need ``numpy`` and ``xarray`` to make our custom data for this example,
and ``matplotlib`` to show what it contains.
We will be using the :mod:`movement.utils.broadcasting` module to
turn our one-dimensional functions into functions that work across
entire ``DataArray`` objects.

.. GENERATED FROM PYTHON SOURCE LINES 18-30

.. code-block:: Python


    # For interactive plots: install ipympl with `pip install ipympl` and uncomment
    # the following lines in your notebook
    # %matplotlib widget
    import matplotlib.pyplot as plt
    import numpy as np
    import xarray as xr

    from movement import sample_data
    from movement.plots import plot_centroid_trajectory
    from movement.utils.broadcasting import make_broadcastable

.. GENERATED FROM PYTHON SOURCE LINES 31-36

Load Sample Dataset
-------------------
First, we load the ``SLEAP_three-mice_Aeon_proofread`` example dataset.
For the rest of this example we'll only need the ``position`` data array, so
we store it in a separate variable.

.. GENERATED FROM PYTHON SOURCE LINES 36-39

.. code-block:: Python


    ds = sample_data.fetch_dataset("SLEAP_three-mice_Aeon_proofread.analysis.h5")
    positions: xr.DataArray = ds.position

.. GENERATED FROM PYTHON SOURCE LINES 40-44

The individuals in this dataset follow very similar, arc-like trajectories.
To help emphasise what we are doing in this example, we will offset the paths
of two of the individuals by a small amount so that the trajectories are more
distinct.

.. GENERATED FROM PYTHON SOURCE LINES 44-48

.. code-block:: Python


    positions.loc[:, "y", :, "AEON3B_TP1"] -= 100.0
    positions.loc[:, "y", :, "AEON3B_TP2"] += 100.0

.. GENERATED FROM PYTHON SOURCE LINES 49-72

.. code-block:: Python


    fig, ax = plt.subplots(1, 1)
    for mouse_name, col in zip(
        positions.individuals.values, ["r", "g", "b"], strict=False
    ):
        plot_centroid_trajectory(
            positions,
            individual=mouse_name,
            keypoints="centroid",
            ax=ax,
            linestyle="-",
            marker=".",
            s=2,
            linewidth=0.5,
            c=col,
            label=mouse_name,
        )
    ax.invert_yaxis()
    ax.set_title("Trajectories")
    ax.set_xlabel("x (pixels)")
    ax.set_ylabel("y (pixels)")
    ax.legend()

.. image-sg:: /examples/images/sphx_glr_broadcasting_your_own_methods_001.png
   :alt: Trajectories
   :srcset: /examples/images/sphx_glr_broadcasting_your_own_methods_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 73-84

Motivation
----------
Suppose that, during our experiment, we have a region of the enclosure that
has a slightly wet floor, making it slippery. The individuals must cross this
region in order to reach some kind of reward on the other side of the
enclosure.
We know that the "slippery region" of our enclosure is approximately
rectangular in shape, and has its opposite corners at (400, 0) and
(600, 2000), where the coordinates are given in pixels.
We could then write a function that determines if a given (x, y) position was
inside this "slippery region".

.. GENERATED FROM PYTHON SOURCE LINES 84-105

.. code-block:: Python


    def in_slippery_region(xy_position) -> bool:
        """Return True if xy_position is in the slippery region.

        Return False otherwise.
        xy_position has 2 elements, the (x, y) coordinates respectively.
        """
        # The slippery region is a rectangle with the following bounds
        x_min, y_min = 400.0, 0.0
        x_max, y_max = 600.0, 2000.0
        # Both bounds are inclusive on every side of the rectangle
        is_within_bounds_x = x_min <= xy_position[0] <= x_max
        is_within_bounds_y = y_min <= xy_position[1] <= y_max
        return is_within_bounds_x and is_within_bounds_y


    # We can just check our function with a few sample points
    for point in [(0, 100), (450, 700), (550, 1500), (601, 500)]:
        print(f"{point} is in slippery region: {in_slippery_region(point)}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (0, 100) is in slippery region: False
    (450, 700) is in slippery region: True
    (550, 1500) is in slippery region: True
    (601, 500) is in slippery region: False

.. GENERATED FROM PYTHON SOURCE LINES 106-111

Determine if each position was slippery
---------------------------------------
Given our data, we could extract whether each position (for each time-point,
and each individual) was inside the slippery region by looping over the
values.

.. GENERATED FROM PYTHON SOURCE LINES 111-153

.. code-block:: Python


    data_shape = positions.shape
    in_slippery = np.zeros(
        shape=(
            len(positions["time"]),
            len(positions["keypoints"]),
            len(positions["individuals"]),
        ),
        dtype=bool,
    )  # We save one result per time-point, per keypoint, per individual

    # Feel free to uncomment the print statements
    # (line-by-line progress through the loop),
    # if you are running this code on your own machine.
    for time_index, time in enumerate(positions["time"].values):
        # print(f"At time {time}:")
        for keypoint_index, keypoint in enumerate(positions["keypoints"].values):
            # print(f"\tAt keypoint {keypoint}")
            for individual_index, individual in enumerate(
                positions["individuals"].values
            ):
                xy_point = positions.sel(
                    time=time,
                    keypoints=keypoint,
                    individuals=individual,
                )
                was_in_slippery = in_slippery_region(xy_point)
                was_in_slippery_text = (
                    "was in slippery region"
                    if was_in_slippery
                    else "was not in slippery region"
                )
                # print(
                #     "\t\tIndividual "
                #     f"{positions['individuals'].values[individual_index]} "
                #     f"{was_in_slippery_text}"
                # )
                # Save our result to our large array
                in_slippery[time_index, keypoint_index, individual_index] = (
                    was_in_slippery
                )

.. GENERATED FROM PYTHON SOURCE LINES 154-156

We could then build a new ``DataArray`` to store our results, so that we can
access the results in the same way that we did our original data.

.. GENERATED FROM PYTHON SOURCE LINES 156-172

.. code-block:: Python


    was_in_slippery_region = xr.DataArray(
        in_slippery,
        dims=["time", "keypoints", "individuals"],
        coords={
            "time": positions["time"],
            "keypoints": positions["keypoints"],
            "individuals": positions["individuals"],
        },
    )

    print(
        "Boolean DataArray indicating if at a given time, "
        "a given individual was inside the slippery region:"
    )
    was_in_slippery_region

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Boolean DataArray indicating if at a given time, a given individual was inside the slippery region:

.. code-block:: none

    <xarray.DataArray (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
        False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'
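
The bookkeeping in this loop can be checked on a tiny synthetic array where
the answer is known in advance. The sketch below uses only NumPy;
``toy_positions``, ``in_rectangle`` and the coordinate values are hypothetical
and are not part of the sample dataset. It only assumes the same dimension
order as ``positions`` (time, space, keypoints, individuals).

.. code-block:: Python

    import numpy as np


    def in_rectangle(xy_position) -> bool:
        """Hypothetical stand-in for in_slippery_region (same bounds)."""
        x, y = xy_position
        return bool(400.0 <= x <= 600.0 and 0.0 <= y <= 2000.0)


    # x-coordinates with shape (time=3, individuals=2); y is fixed at 1000
    x = np.array([[300.0, 450.0], [500.0, 550.0], [700.0, 650.0]])
    y = np.full_like(x, 1000.0)
    # Stack into shape (time=3, space=2, keypoints=1, individuals=2)
    toy_positions = np.stack([x, y], axis=1)[:, :, np.newaxis, :]

    n_time, _, n_keypoints, n_individuals = toy_positions.shape
    toy_in_slippery = np.zeros((n_time, n_keypoints, n_individuals), dtype=bool)
    for t in range(n_time):
        for k in range(n_keypoints):
            for i in range(n_individuals):
                toy_in_slippery[t, k, i] = in_rectangle(toy_positions[t, :, k, i])

    # The result has no "space" axis: one boolean per (time, keypoint, individual)
    print(toy_in_slippery.shape)
    # The first individual is inside the region only at the middle time point
    print(toy_in_slippery[:, 0, 0])

The three nested loops are tied to this exact dimension layout, and that is
the limitation discussed in the rest of this example.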


.. GENERATED FROM PYTHON SOURCE LINES 173-175

We can now get the first and last times that an individual was inside the
slippery region by examining this ``DataArray``.

.. GENERATED FROM PYTHON SOURCE LINES 175-186

.. code-block:: Python


    i_id = "AEON3B_NTP"
    individual_0_centroid = was_in_slippery_region.sel(
        individuals=i_id, keypoints="centroid"
    )
    first_entry = individual_0_centroid["time"][individual_0_centroid].values[0]
    last_exit = individual_0_centroid["time"][individual_0_centroid].values[-1]
    print(
        f"{i_id} first entered the slippery region at "
        f"{first_entry} and last exited at {last_exit}"
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    AEON3B_NTP first entered the slippery region at 2.1 and last exited at 8.64

.. GENERATED FROM PYTHON SOURCE LINES 187-200

Data Generalisation Issues
--------------------------
The shape of the resulting ``DataArray`` is the same as our original
``DataArray``, but without the ``"space"`` dimension.
Indeed, we have essentially collapsed the ``"space"`` dimension, since our
``in_slippery_region`` function takes in a 1D data slice (the x, y positions
of a single individual's centroid at a given point in time) and returns a
scalar value (True/False).
However, having to construct a new ``DataArray`` after running our function
over all space slices in our ``DataArray`` does not scale: our ``for`` loop
approach relied on knowing how many dimensions our data had (and the size of
those dimensions). We have no guarantee that the next ``DataArray`` that
comes in will have the same structure.

.. GENERATED FROM PYTHON SOURCE LINES 202-211

Making our Function Broadcastable
---------------------------------
To combat this problem, we can make the observation that given any
``DataArray``, we always want to broadcast our ``in_slippery_region`` function
along the ``"space"`` dimension. By "broadcast", we mean that we always want
to run our function for each 1D-slice in the ``"space"`` dimension, since
these are the (x, y) coordinates. As such, we can decorate our function with
the ``make_broadcastable`` decorator:

.. GENERATED FROM PYTHON SOURCE LINES 211-218

.. code-block:: Python


    @make_broadcastable()
    def in_slippery_region_broadcastable(xy_position) -> bool:
        return in_slippery_region(xy_position=xy_position)

.. GENERATED FROM PYTHON SOURCE LINES 219-223

Note that when writing your own methods, there is no need to have both
``in_slippery_region`` and ``in_slippery_region_broadcastable``; simply apply
the ``make_broadcastable`` decorator to ``in_slippery_region`` directly. We've
made two separate functions here to illustrate what's going on.

.. GENERATED FROM PYTHON SOURCE LINES 225-227

``in_slippery_region_broadcastable`` is usable in exactly the same ways as
``in_slippery_region`` was:

.. GENERATED FROM PYTHON SOURCE LINES 227-235

.. code-block:: Python


    for point in [(0, 100), (450, 700), (550, 1500), (601, 500)]:
        print(
            f"{point} is in slippery region: "
            f"{in_slippery_region_broadcastable(point)}"
        )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (0, 100) is in slippery region: False
    (450, 700) is in slippery region: True
    (550, 1500) is in slippery region: True
    (601, 500) is in slippery region: False

.. GENERATED FROM PYTHON SOURCE LINES 236-241

However, ``in_slippery_region_broadcastable`` also accepts a ``DataArray`` as
its first (``xy_position``) argument, along with an extra keyword argument,
``broadcast_dimension``. Together, these let us broadcast across the given
dimension of the input ``DataArray``, treating each 1D-slice as a separate
input to ``in_slippery_region``.

.. GENERATED FROM PYTHON SOURCE LINES 241-250

.. code-block:: Python


    in_slippery_region_broadcasting = in_slippery_region_broadcastable(
        positions,  # Now a DataArray input
        broadcast_dimension="space",
    )

    print("DataArray output using broadcasting: ")
    in_slippery_region_broadcasting

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    DataArray output using broadcasting:

.. code-block:: none

    <xarray.DataArray 'position' (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
        False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'


.. GENERATED FROM PYTHON SOURCE LINES 251-256

Calling ``in_slippery_region_broadcastable`` in this way gives us a
``DataArray`` output - and one that retains any information that was in our
original ``DataArray`` to boot! The result is exactly the same as what we got
from using our ``for`` loop, and then adding the extra information to the
result.

.. GENERATED FROM PYTHON SOURCE LINES 256-262

.. code-block:: Python


    # Throws an AssertionError if the two inputs are not the same
    xr.testing.assert_equal(
        was_in_slippery_region, in_slippery_region_broadcasting
    )

.. GENERATED FROM PYTHON SOURCE LINES 263-266

But importantly, ``in_slippery_region_broadcastable`` also works on
``DataArray`` objects with different dimensions.
For example, we could have pre-selected one of our individuals beforehand.

.. GENERATED FROM PYTHON SOURCE LINES 266-280

.. code-block:: Python


    i_id = "AEON3B_NTP"
    individual_0 = positions.sel(individuals=i_id)

    individual_0_in_slippery_region = in_slippery_region_broadcastable(
        individual_0,
        broadcast_dimension="space",
    )

    print(
        "We get a 2D DataArray output from our 3D input, "
        "again with the 'space' dimension that we broadcast along collapsed:"
    )
    individual_0_in_slippery_region

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    We get a 2D DataArray output from our 3D input, again with the 'space' dimension that we broadcast along collapsed:

.. code-block:: none

    <xarray.DataArray 'position' (time: 601, keypoints: 1)> Size: 601B
        False False False False False False ... False False False False False False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
        individuals  <U10 40B 'AEON3B_NTP'
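
To make the idea of broadcasting along a dimension concrete, the following is
a minimal NumPy-only sketch that does not use ``movement`` at all. It assumes
synthetic data and a hypothetical ``in_rectangle`` helper that mirrors
``in_slippery_region``. Here ``np.apply_along_axis`` stands in for the
decorator; it is only an analogy, not a description of how
``make_broadcastable`` is implemented, and unlike the decorator it returns a
plain array without any coordinate labels.

.. code-block:: Python

    import numpy as np


    def in_rectangle(xy_position) -> bool:
        """Hypothetical stand-in for in_slippery_region (same bounds)."""
        x, y = xy_position
        return bool(400.0 <= x <= 600.0 and 0.0 <= y <= 2000.0)


    # Synthetic positions with shape (time=2, space=2, keypoints=1, individuals=3)
    rng = np.random.default_rng(seed=0)
    toy_positions = rng.uniform(0.0, 1000.0, size=(2, 2, 1, 3))

    # Apply the 1D function to every slice along axis 1 (the "space" axis)
    mask_4d = np.apply_along_axis(in_rectangle, 1, toy_positions)
    print(mask_4d.shape)  # (2, 1, 3): the "space" axis has been collapsed

    # The same call works after selecting one individual (a 3D input)
    mask_3d = np.apply_along_axis(in_rectangle, 1, toy_positions[..., 0])
    print(mask_3d.shape)  # (2, 1)

In both calls the function only ever sees one (x, y) pair at a time, so it
does not need to know how many other dimensions the input has.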


.. GENERATED FROM PYTHON SOURCE LINES 281-293

Additional Function Arguments
-----------------------------
So far our ``in_slippery_region`` method only takes a single argument,
the ``xy_position`` itself. However, in follow-up experiments, we might
move the slippery region in the enclosure, and so adapt our existing function
to make it more general.
It will now allow someone to input a custom rectangular region, by specifying
the minimum and maximum ``(x, y)`` coordinates of the rectangle, rather than
relying on fixed values inside the function.
The default region will be the rectangle from our first experiment, and we
still want to be able to broadcast this function.
And so we write a more general function, as below.

.. GENERATED FROM PYTHON SOURCE LINES 293-316

.. code-block:: Python


    @make_broadcastable()
    def in_slippery_region_general(
        xy_position, xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)
    ) -> bool:
        """Return True if xy_position is in the slippery region.

        Return False otherwise.
        xy_position has 2 elements, the (x, y) coordinates respectively.
        """
        x_min, y_min = xy_min
        x_max, y_max = xy_max
        is_within_bounds_x = x_min <= xy_position[0] <= x_max
        is_within_bounds_y = y_min <= xy_position[1] <= y_max
        return is_within_bounds_x and is_within_bounds_y


    # (0.5, 0.5) is in the unit square whose bottom left corner is at the origin
    print(in_slippery_region_general((0.5, 0.5), (0.0, 0.0), (1.0, 1.0)))
    # But (0.5,0.5) is not in a unit square whose bottom left corner is at (1,1)
    print(in_slippery_region_general((0.5, 0.5), (1.0, 1.0), (2.0, 2.0)))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    True
    False

.. GENERATED FROM PYTHON SOURCE LINES 317-320

``make_broadcastable`` retains the additional arguments of the function we
define. However, the ``xy_position`` argument has to be the first argument
that appears in the ``def`` statement.

.. GENERATED FROM PYTHON SOURCE LINES 320-329

.. code-block:: Python


    # Default arguments should give us the same results as before
    xr.testing.assert_equal(
        was_in_slippery_region, in_slippery_region_general(positions)
    )
    # But we can also provide the optional arguments in the same way as with the
    # un-decorated function.
    in_slippery_region_general(positions, xy_min=(100, 0), xy_max=(400, 1000))

.. code-block:: none

    <xarray.DataArray 'position' (time: 601, keypoints: 1, individuals: 3)> Size: 2kB
        False False True False False True False ... True True False True True False
    Coordinates:
      * time         (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'
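
One way to picture why the extra arguments survive decoration, and why the
data must come first, is a toy decorator that iterates over its first
argument and forwards everything else unchanged. The sketch below is plain
Python and purely illustrative: ``broadcast_over_points`` and ``in_region``
are hypothetical names, and this is not the implementation used by
``movement``.

.. code-block:: Python

    import functools


    def broadcast_over_points(func):
        """Toy decorator: apply ``func`` to each point in a list of points.

        Only the first positional argument is iterated over; every other
        argument is forwarded to ``func`` unchanged.
        """

        @functools.wraps(func)
        def wrapper(points, *args, **kwargs):
            return [func(point, *args, **kwargs) for point in points]

        return wrapper


    @broadcast_over_points
    def in_region(xy_position, xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)):
        """Return True if xy_position lies inside the rectangle."""
        x, y = xy_position
        return xy_min[0] <= x <= xy_max[0] and xy_min[1] <= y <= xy_max[1]


    points = [(450, 700), (601, 500)]
    default_result = in_region(points)
    custom_result = in_region(points, xy_min=(600, 0), xy_max=(700, 1000))
    print(default_result)  # [True, False]
    print(custom_result)  # [False, True]

Because the wrapper singles out its first positional argument, a function
whose data argument came second could not be decorated this way without
reordering its parameters.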


.. GENERATED FROM PYTHON SOURCE LINES 330-336

Only Broadcast Along Select Dimensions
--------------------------------------
The ``make_broadcastable`` decorator has some flexibility with its input
arguments, to help you avoid unintentional behaviour. You may have noticed,
for example, that there is nothing stopping someone who wants to use your
analysis code from trying to broadcast along the wrong dimension.

.. GENERATED FROM PYTHON SOURCE LINES 336-344

.. code-block:: Python


    silly_broadcast = in_slippery_region_broadcastable(
        positions, broadcast_dimension="time"
    )

    print("The output has collapsed the time dimension:")
    silly_broadcast

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The output has collapsed the time dimension:

.. code-block:: none

    <xarray.DataArray 'position' (space: 2, keypoints: 1, individuals: 3)> Size: 6B
        False False False False False False
    Coordinates:
      * space        (space) <U1 8B 'x' 'y'
      * keypoints    (keypoints) <U8 32B 'centroid'
      * individuals  (individuals) <U10 120B 'AEON3B_NTP' 'AEON3B_TP1' 'AEON3B_TP2'


.. GENERATED FROM PYTHON SOURCE LINES 345-355

There is no error thrown because functionally, this is a valid operation.
The time slices of our data were 1D, so we can run ``in_slippery_region`` on
them. But each slice isn't a position; it's an array of one spatial coordinate
(e.g. x) for each keypoint, each individual, at every time! So from an
analysis standpoint, doing this doesn't make sense and isn't how we intend our
function to be used.

We can pass the ``only_broadcastable_along`` keyword argument to
``make_broadcastable`` to prevent these kinds of mistakes, and make our
intentions clearer.

.. GENERATED FROM PYTHON SOURCE LINES 355-362

.. code-block:: Python


    @make_broadcastable(only_broadcastable_along="space")
    def in_slippery_region_space_only(xy_position):
        return in_slippery_region(xy_position)

.. GENERATED FROM PYTHON SOURCE LINES 363-365

Now, ``in_slippery_region_space_only`` no longer takes the
``broadcast_dimension`` argument.

.. GENERATED FROM PYTHON SOURCE LINES 365-374

.. code-block:: Python


    try:
        in_slippery_region_space_only(
            positions,
            broadcast_dimension="time",
        )
    except TypeError as e:
        print(f"Got a TypeError when trying to run, here's the message:\n{e}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Got a TypeError when trying to run, here's the message:
    __main__.in_slippery_region_space_only() got multiple values for keyword argument 'broadcast_dimension'

.. GENERATED FROM PYTHON SOURCE LINES 375-382

The error we get seems to be telling us that we've tried to set the value
of ``broadcast_dimension`` twice. Specifying
``only_broadcastable_along = "space"`` forces ``broadcast_dimension`` to be
set to ``"space"``, so trying to set it again (even to the same value)
results in an error.
However, ``in_slippery_region_space_only`` knows to only use the ``"space"``
dimension of the input by default.

.. GENERATED FROM PYTHON SOURCE LINES 382-389

.. code-block:: Python


    was_in_view_space_only = in_slippery_region_space_only(positions)

    xr.testing.assert_equal(
        in_slippery_region_broadcasting, was_in_view_space_only
    )

.. GENERATED FROM PYTHON SOURCE LINES 390-394

It is worth noting that there is a "helper" decorator,
``space_broadcastable``, that essentially does the same thing as
``make_broadcastable(only_broadcastable_along="space")``.
You can use this decorator for your own convenience.

.. GENERATED FROM PYTHON SOURCE LINES 396-400

Extending to Class Methods
--------------------------
``make_broadcastable`` can also be applied to methods defined on a class (such
as the ``is_inside`` method below, which takes ``self``), though it needs to be
told that you are doing so via the ``is_classmethod`` parameter.

.. GENERATED FROM PYTHON SOURCE LINES 400-430

.. code-block:: Python


    class Rectangle:
        """Represents an observing camera in the experiment."""

        xy_min: tuple[float, float]
        xy_max: tuple[float, float]

        def __init__(self, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0)):
            """Create a new instance."""
            self.xy_min = tuple(xy_min)
            self.xy_max = tuple(xy_max)

        @make_broadcastable(is_classmethod=True, only_broadcastable_along="space")
        def is_inside(self, /, xy_position) -> bool:
            """Whether the position is inside the rectangle."""
            # For the sake of brevity, we won't redefine the entire method here,
            # and will just call our existing function.
            return in_slippery_region_general(
                xy_position, self.xy_min, self.xy_max
            )


    slippery_region = Rectangle(xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0))
    was_in_region_clsmethod = slippery_region.is_inside(positions)

    xr.testing.assert_equal(
        was_in_region_clsmethod, in_slippery_region_broadcasting
    )

.. GENERATED FROM PYTHON SOURCE LINES 431-434

The ``broadcastable_method`` decorator is provided as a helpful alias for
``make_broadcastable(is_classmethod=True)``, and otherwise works in the same
way (and accepts the same parameters).

.. GENERATED FROM PYTHON SOURCE LINES 434-468

.. code-block:: Python


    # The alias lives alongside make_broadcastable
    from movement.utils.broadcasting import broadcastable_method


    class RectangleAlternative:
        """Represents an observing camera in the experiment."""

        xy_min: tuple[float, float]
        xy_max: tuple[float, float]

        def __init__(self, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0)):
            """Create a new instance."""
            self.xy_min = tuple(xy_min)
            self.xy_max = tuple(xy_max)

        @broadcastable_method(only_broadcastable_along="space")
        def is_inside(self, /, xy_position) -> bool:
            """Whether the position is inside the rectangle."""
            # For the sake of brevity, we won't redefine the entire method here,
            # and will just call our existing function.
            return in_slippery_region_general(
                xy_position, self.xy_min, self.xy_max
            )


    slippery_region_alt = RectangleAlternative(
        xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)
    )
    was_in_region_clsmethod_alt = slippery_region_alt.is_inside(positions)

    xr.testing.assert_equal(
        was_in_region_clsmethod_alt, in_slippery_region_broadcasting
    )
    xr.testing.assert_equal(was_in_region_clsmethod_alt, was_in_region_clsmethod)

.. GENERATED FROM PYTHON SOURCE LINES 469-472

In fact, if you look at the Regions of Interest submodule, and in particular
the classes inside it, you'll notice that we use the ``broadcastable_method``
decorator ourselves in some of these methods!


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.129 seconds)


.. _sphx_glr_download_examples_broadcasting_your_own_methods.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/neuroinformatics-unit/movement/gh-pages?filepath=notebooks/examples/broadcasting_your_own_methods.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: broadcasting_your_own_methods.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: broadcasting_your_own_methods.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: broadcasting_your_own_methods.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_