TypedIterator

This example shows how to use the TypedIterator class, which can be helpful when iterating over data and storing multiple results for each iteration.

The Problem

A very common pattern when working with any type of data is to iterate over it and apply a series of operations to each element. In simple cases, you might only want to store the final result, but often you are also interested in intermediate or alternative outputs.

Typically, you create multiple empty lists or dictionaries (one for each result) and then append the results to them during the iteration. At the end, you might apply further operations to the results, e.g. aggregations.

Below is a simple example of this pattern:
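A minimal version of this pattern is shown below. It assumes the input data is the list [1, 2, 3, 4, 5], which matches the printed results. It keeps one container per result, appends inside the loop, and aggregates after the loop:

```python
data = [1, 2, 3, 4, 5]

# 1st place to think about results: one empty container per result
result_1 = []
result_2 = []
result_3 = []

for d in data:
    r1 = d * 3
    r2 = r1 * 2
    r3 = r2 - 4
    # 2nd place: store every intermediate result in its own container
    result_1.append(r1)
    result_2.append(r2)
    result_3.append(r3)

# 3rd place: aggregate after the loop
result_1 = sum(result_1)

print(result_1)
print(result_2)
print(result_3)
```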

45
[6, 12, 18, 24, 30]
[2, 8, 14, 20, 26]

Fundamentally, this pattern works well. However, it does not really fit the declarative style of code that we are trying to achieve with tpcp. There are three places where you need to think about each result and its type: when creating the empty containers, when appending to them inside the loop, and when aggregating them at the end. This makes the code harder to reason about and harder to change later on. In addition, the main pipeline code, which should be the most important part, is cluttered with boilerplate that is only concerned with storing the results.

While we could fix some of these issues by refactoring a little, with TypedIterator we provide (in our opinion) a much cleaner solution.

The basic idea of TypedIterator is to specify all configuration (i.e. which results to expect and how to aggregate them) in one place at the beginning. It further simplifies storing results by inverting the data structure. Instead of maintaining one data structure per result, you only deal with one object per iteration. Because these objects are dataclasses, they are typed, which helps prevent typos and provides IDE support.
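The inverted structure can be sketched without any library support, using a plain dataclass. Each iteration produces one typed result object, and the list of objects is transposed back into one list per field afterwards. This only illustrates the idea and is not how TypedIterator is implemented:

```python
from dataclasses import dataclass, fields


@dataclass
class ResultType:
    result_1: int
    result_2: int
    result_3: int


data = [1, 2, 3, 4, 5]

# Inverted structure: one typed result object per iteration ...
raw_results = []
for d in data:
    r1 = d * 3
    r2 = r1 * 2
    raw_results.append(ResultType(result_1=r1, result_2=r2, result_3=r2 - 4))

# ... which can be transposed back into one list per result field afterwards
per_field = {f.name: [getattr(r, f.name) for r in raw_results] for f in fields(ResultType)}

print(per_field["result_2"])  # [6, 12, 18, 24, 30]
```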

Let’s rewrite the above example using TypedIterator:

  1. We define our result datatype as a dataclass.

from dataclasses import dataclass


@dataclass
class ResultType:
    result_1: int
    result_2: int
    result_3: int

  2. We define the aggregations we want to apply to the results. If we don’t want to aggregate a result, we simply don’t add it to the list. Aggregations are explained in more detail below; for now, just accept this as given.

aggregations = [
    ("result_1", lambda _, results: sum(results)),
]

  3. We create a new instance of TypedIterator with the result type and the aggregations.

from tpcp.misc import TypedIterator

iterator = TypedIterator(ResultType, aggregations=aggregations)

Now we can iterate over our data and get a result object for each iteration, which we then fill with the results.

for d, r in iterator.iterate(data):
    r.result_1 = d * 3
    r.result_2 = r.result_1 * 2
    r.result_3 = r.result_2 - 4

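The aggregated results are now available as attributes of the iterator, named after the result fields with a trailing underscore (e.g. result_1_). Fields without an aggregation, such as result_2 and result_3, remain lists with one entry per iteration.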

45
[6, 12, 18, 24, 30]
[2, 8, 14, 20, 26]

The raw results are available as a list of dataclass instances.

iterator.raw_results_
[ResultType(result_1=3, result_2=6, result_3=2), ResultType(result_1=6, result_2=12, result_3=8), ResultType(result_1=9, result_2=18, result_3=14), ResultType(result_1=12, result_2=24, result_3=20), ResultType(result_1=15, result_2=30, result_3=26)]
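If you prefer a tabular view of scalar results, the list of dataclass instances can be turned into a pandas DataFrame with dataclasses.asdict. This is a minimal sketch that uses a stand-in list with the same values as the output above, rather than the actual iterator:

```python
from dataclasses import asdict, dataclass

import pandas as pd


@dataclass
class ResultType:
    result_1: int
    result_2: int
    result_3: int


# Stand-in for iterator.raw_results_, rebuilt with the values printed above
raw_results = [ResultType(result_1=3 * d, result_2=6 * d, result_3=6 * d - 4) for d in [1, 2, 3, 4, 5]]

# One row per iteration, one column per result field
df = pd.DataFrame([asdict(r) for r in raw_results])
print(df)
```

This works well for scalar fields. For fields that hold whole series, like the r-peak positions in the next section, an aggregation is the better fit.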

While this version of the code requires a couple more lines, it is much easier to understand and reason about. It clearly separates the configuration from the processing code, and the core pipeline code is much cleaner.
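The sketch below is a heavily simplified, hypothetical re-implementation that makes the mechanics more concrete. The class name MiniTypedIterator, the use of None as the initial field value, and results_ as a plain dict are all assumptions of this sketch, not tpcp's API. It shows one way such an iterator can work: a generator hands out a fresh result object per item and runs the aggregations once the loop is exhausted.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable, Iterator
from dataclasses import dataclass, fields
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class MiniTypedIterator(Generic[T]):
    """Hypothetical, simplified stand-in for TypedIterator (not tpcp's code)."""

    def __init__(self, result_type: type[T], aggregations: Iterable[tuple[str, Callable[[list, list], Any]]] = ()):
        self.result_type = result_type
        self.aggregations = dict(aggregations)

    def iterate(self, data: Iterable[Any]) -> Iterator[tuple[Any, T]]:
        self.inputs_: list[Any] = []
        self.raw_results_: list[T] = []
        for d in data:
            # Fresh result object per iteration; fields start as None and are filled by the loop body
            r = self.result_type(**{f.name: None for f in fields(self.result_type)})
            self.inputs_.append(d)
            self.raw_results_.append(r)
            yield d, r
        # Runs only once the caller's for-loop has exhausted the generator
        self.results_: dict[str, Any] = {}
        for f in fields(self.result_type):
            values = [getattr(r, f.name) for r in self.raw_results_]
            aggregate = self.aggregations.get(f.name)
            # Aggregation functions receive (inputs, results); unaggregated fields stay lists
            self.results_[f.name] = aggregate(self.inputs_, values) if aggregate else values


@dataclass
class ResultType:
    result_1: int
    result_2: int
    result_3: int


iterator = MiniTypedIterator(ResultType, aggregations=[("result_1", lambda _, results: sum(results))])
for d, r in iterator.iterate([1, 2, 3, 4, 5]):
    r.result_1 = d * 3
    r.result_2 = r.result_1 * 2
    r.result_3 = r.result_2 - 4

print(iterator.results_["result_1"])  # 45
```

In this sketch, the aggregations only run when the for-loop finishes normally. Breaking out early leaves results_ unset.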

A real-world example

Below, we apply this pattern to a pipeline that iterates over an actual dataset. The result types are a bit more complex, to show some more advanced features of aggregations.

For this example, we apply the QRS detection algorithm to the ECG dataset used in some of the other examples. The QRS detection algorithm only has a single output, so we add the number of r-peaks as a second result to demonstrate the multi-result use case.

Again we start by defining the result dataclass.

import pandas as pd


@dataclass
class QRSResultType:
    """The result type of the QRS detection algorithm."""

    r_peak_positions: pd.Series
    n_r_peaks: int

For the aggregations, we want to concatenate the r-peak positions. Each aggregation function receives the list of inputs (here, the datapoints of the dataset) as the first argument and the list of results for that field as the second argument. We use this to combine the per-recording results into a single series whose index contains the group label of each datapoint.

We turn n_r_peaks into a dictionary keyed by group label, which makes it easy to map the results back to the inputs.

aggregations = [
    (
        "r_peak_positions",
        lambda datapoints, results: pd.concat(results, keys=[d.group_label for d in datapoints]),
    ),
    (
        "n_r_peaks",
        lambda datapoints, results: dict(zip([d.group_label for d in datapoints], results)),
    ),
]
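The following toy run shows what these two aggregation functions return. It uses hypothetical stand-in datapoints that only provide a group_label attribute, with made-up results. It is not the real ECG dataset:

```python
from types import SimpleNamespace

import pandas as pd

# Hypothetical stand-ins for two dataset rows; the aggregations only use `.group_label`
datapoints = [SimpleNamespace(group_label="rec_a"), SimpleNamespace(group_label="rec_b")]
# Toy per-datapoint results, in the same order as `datapoints`
peak_results = [pd.Series([10, 20, 30]), pd.Series([5, 15])]
count_results = [3, 2]

# The same two aggregation functions as above
aggregations = [
    ("r_peak_positions", lambda datapoints, results: pd.concat(results, keys=[d.group_label for d in datapoints])),
    ("n_r_peaks", lambda datapoints, results: dict(zip([d.group_label for d in datapoints], results))),
]
agg = dict(aggregations)

combined = agg["r_peak_positions"](datapoints, peak_results)
counts = agg["n_r_peaks"](datapoints, count_results)

# The group label becomes the outer index level, the original series index the inner one
print(combined.loc["rec_a"].tolist())  # [10, 20, 30]
print(counts)  # {'rec_a': 3, 'rec_b': 2}
```

With the real dataset, the group labels are tuple-like (patient group, participant), so pandas expands them into several outer index levels. This is visible in the output further below.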

Now we can create the iterator and iterate over the dataset.

from pathlib import Path

from examples.algorithms.algorithms_qrs_detection_final import QRSDetector
from examples.datasets.datasets_final_ecg import ECGExampleData

iterator = TypedIterator(QRSResultType, aggregations=aggregations)

try:
    HERE = Path(__file__).parent
except NameError:
    HERE = Path().resolve()
data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"

dataset = ECGExampleData(data_path)

for d, r in iterator.iterate(dataset):
    r.r_peak_positions = QRSDetector().detect(d.data["ecg"], sampling_rate_hz=d.sampling_rate_hz).r_peak_positions_
    r.n_r_peaks = len(r.r_peak_positions)

Finally, we can inspect the results stored on the iterator. Note that r_peak_positions_ is now a single series with a hierarchical index (group label, followed by the original per-recording index) instead of a list of series.

group_1  100  0           77
              1          370
              2          663
              3          947
              4         1231
                       ...
group_3  200  1448    647546
              1449    648357
              1450    648629
              1451    649409
              1452    649928
Length: 17782, dtype: int64

n_r_peaks_ is a dictionary mapping each group label to its number of r-peaks, as expected.

{ECGExampleDataGroupLabel(patient_group='group_1', participant='100'): 2270, ECGExampleDataGroupLabel(patient_group='group_2', participant='102'): 1710, ECGExampleDataGroupLabel(patient_group='group_3', participant='104'): 2066, ECGExampleDataGroupLabel(patient_group='group_1', participant='105'): 2567, ECGExampleDataGroupLabel(patient_group='group_2', participant='106'): 1704, ECGExampleDataGroupLabel(patient_group='group_3', participant='108'): 78, ECGExampleDataGroupLabel(patient_group='group_1', participant='114'): 30, ECGExampleDataGroupLabel(patient_group='group_2', participant='116'): 2392, ECGExampleDataGroupLabel(patient_group='group_3', participant='119'): 1988, ECGExampleDataGroupLabel(patient_group='group_1', participant='121'): 6, ECGExampleDataGroupLabel(patient_group='group_2', participant='123'): 1518, ECGExampleDataGroupLabel(patient_group='group_3', participant='200'): 1453}

The raw results are still available as a list of dataclass instances.

iterator.raw_results_
[QRSResultType(r_peak_positions=0           77
1          370
2          663
3          947
4         1231
         ...
2265    648978
2266    649232
2267    649485
2268    649734
2269    649992
Length: 2270, dtype: int64, n_r_peaks=2270), QRSResultType(r_peak_positions=0          409
1          697
2          988
3         1304
4         1613
         ...
1705    648639
1706    648930
1707    649243
1708    649553
1709    649851
Length: 1710, dtype: int64, n_r_peaks=1710), QRSResultType(r_peak_positions=0           17
1          314
2          613
3          899
4         1186
         ...
2061    648729
2062    649021
2063    649298
2064    649578
2065    649874
Length: 2066, dtype: int64, n_r_peaks=2066), QRSResultType(r_peak_positions=0          197
1          459
2          708
3          964
4         1221
         ...
2562    648733
2563    648977
2564    649221
2565    649471
2566    649740
Length: 2567, dtype: int64, n_r_peaks=2567), QRSResultType(r_peak_positions=0          351
1          725
2         1086
3         1448
4         1830
         ...
1699    648969
1700    649161
1701    649335
1702    649792
1703    649990
Length: 1704, dtype: int64, n_r_peaks=1704), QRSResultType(r_peak_positions=0      10875
1     168524
2     169689
3     170426
4     170802
       ...
73    343872
74    359503
75    361856
76    472918
77    526420
Length: 78, dtype: int64, n_r_peaks=78), QRSResultType(r_peak_positions=0     281594
1     281953
2     282291
3     299048
4     300134
5     300486
6     303833
7     304565
8     305674
9     306034
10    307475
11    314265
12    354268
13    469019
14    477999
15    512726
16    513064
17    513384
18    627927
19    629093
20    629709
21    630859
22    631156
23    636224
24    636519
25    636821
26    637118
27    637399
28    637652
29    638046
dtype: int64, n_r_peaks=30), QRSResultType(r_peak_positions=0           16
1          284
2          562
3          838
4         1105
         ...
2387    648934
2388    649192
2389    649444
2390    649703
2391    649958
Length: 2392, dtype: int64, n_r_peaks=2392), QRSResultType(r_peak_positions=0          309
1          504
2          977
3         1315
4         1651
         ...
1983    648792
1984    649129
1985    649468
1986    649788
1987    649985
Length: 1988, dtype: int64, n_r_peaks=1988), QRSResultType(r_peak_positions=0      1569
1     88217
2     92814
3    168263
4    301711
5    581676
dtype: int64, n_r_peaks=6), QRSResultType(r_peak_positions=0           71
1          551
2         1022
3         1499
4         1926
         ...
1513    648248
1514    648627
1515    648999
1516    649343
1517    649690
Length: 1518, dtype: int64, n_r_peaks=1518), QRSResultType(r_peak_positions=0          488
1          965
2         1434
3         1883
4         2332
         ...
1448    647546
1449    648357
1450    648629
1451    649409
1452    649928
Length: 1453, dtype: int64, n_r_peaks=1453)]

And the inputs are stored as well.

[ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_1         100, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_2         102, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_3         104, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_1         105, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_2         106, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_3         108, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_1         114, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_2         116, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_3         119, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_1         121, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_2         123, ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_3         200]

Custom Iterators

When passing an iterable directly is not convenient, you can also create a custom iterator class that reimplements iterate with custom logic. For example, you could provide a custom iterator that takes a data and a sections parameter and then loops over the sections of the data.

For this we need to create a custom subclass inheriting from BaseTypedIterator.

from collections.abc import Iterator

from tpcp.misc import BaseTypedIterator


class SectionIterator(BaseTypedIterator[QRSResultType]):
    def iterate(self, data: pd.DataFrame, sections: pd.DataFrame) -> Iterator[tuple[pd.DataFrame, QRSResultType]]:
        # We turn the sections into a generator of dataframes
        data_iterable = (data.iloc[s.start : s.end] for s in sections.itertuples(index=False))
        # We use the `_iterate` method to do the heavy lifting
        yield from self._iterate(data_iterable)

We create some dummy data and sections to test the iterator.

dummy_data = pd.DataFrame({"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})
dummy_sections = pd.DataFrame({"start": [0, 5], "end": [5, 10]})
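The slicing inside SectionIterator.iterate can be checked on its own with the same dummy data. This standalone sketch only exercises the generator expression, not BaseTypedIterator:

```python
import pandas as pd

dummy_data = pd.DataFrame({"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})
dummy_sections = pd.DataFrame({"start": [0, 5], "end": [5, 10]})

# Same generator expression as in SectionIterator.iterate: one slice per section row.
# iloc slicing is end-exclusive, so "end" is the first row *after* each section.
data_iterable = (dummy_data.iloc[s.start : s.end] for s in dummy_sections.itertuples(index=False))
chunks = list(data_iterable)

print([len(c) for c in chunks])  # [5, 5]
```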

Now we can use the iterator to iterate over the data. We skip any form of aggregation here, as it is not really relevant for this example, but it would work the same way as before.

@dataclass
class SimpleResultType:
    n_samples: int


custom_iterator = SectionIterator(SimpleResultType)

for d, r in custom_iterator.iterate(dummy_data, dummy_sections):
    print(d)
    r.n_samples = len(d)

   data
0     1
1     2
2     3
3     4
4     5
   data
5     6
6     7
7     8
8     9
9    10

As expected, the iterator iterated over the two sections of the data, and the raw results contain two instances of the result dataclass.

custom_iterator.raw_results_
[SimpleResultType(n_samples=5), SimpleResultType(n_samples=5)]
[5, 5]
[   data
0     1
1     2
2     3
3     4
4     5,    data
5     6
6     7
7     8
8     9
9    10]

Total running time of the script: (0 minutes 5.589 seconds)

Estimated memory usage: 9 MB
