TypedIterator#

This example shows how to use the TypedIterator class, which can be helpful when you iterate over data and need to store multiple results for each iteration.

The Problem#

A very common pattern when working with any type of data is to iterate over it and then apply a series of operations to it. In simple cases you might only want to store the final result, but often you are also interested in intermediate or alternative outputs.

Typically, you create multiple empty lists or dictionaries (one for each result) and append the results to them during the iteration. At the end, you might apply further operations to the results, e.g. aggregations.

Below is a simple example of this pattern:

45
[6, 12, 18, 24, 30]
[2, 8, 14, 20, 26]
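
The code that produced this output is not reproduced above. A minimal reconstruction, assuming the input `data = [1, 2, 3, 4, 5]` and the same three operations that the TypedIterator version further down applies, might look like the sketch below. The variable names are illustrative, and the comments mark what are presumably the places where the results need attention.

```python
# Assumed input; it matches the inputs visible in the raw results later on.
data = [1, 2, 3, 4, 5]

# Place 1: one empty container per result.
result_1 = []
result_2 = []
result_3 = []

for d in data:
    intermediate_1 = d * 3
    intermediate_2 = intermediate_1 * 2
    intermediate_3 = intermediate_2 - 4
    # Place 2: store every result in its own container.
    result_1.append(intermediate_1)
    result_2.append(intermediate_2)
    result_3.append(intermediate_3)

# Place 3: aggregate after the loop.
result_1 = sum(result_1)

print(result_1)
print(result_2)
print(result_3)
```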

Fundamentally, this pattern works well. However, it does not really fit the idea of declarative code that we are trying to achieve with tpcp. While programming, there are three places where you need to think about the results and their types: where you create the containers, where you store the results inside the loop, and where you aggregate them afterwards. This makes the code harder to reason about and harder to change later on. In addition, the main pipeline code, which should be the most important part, is cluttered with boilerplate that only exists to store the results.

While we could fix some of these issues by refactoring a little, with TypedIterator we provide (in our opinion) a much cleaner solution.

The basic idea of TypedIterator is to provide a way to specify all configuration (i.e. what results to expect and how to aggregate them) in one place at the beginning. It further simplifies how to store results, by inverting the data structure. Instead of worrying about one data structure for each result, you only need to worry about one data structure for each iteration. Using dataclasses, these objects are also typed, preventing typos and providing IDE support.
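
To make the "inverted" data structure concrete, here is a stdlib-only sketch of the idea. It is not tpcp's implementation: `Record` and `invert` are hypothetical names, and the aggregation functions here receive the plain records rather than the result tuples that tpcp passes.

```python
from dataclasses import dataclass, fields


@dataclass
class Record:
    # Hypothetical per-iteration result object (one instance per iteration).
    doubled: int
    squared: int


def invert(records, aggregations=None):
    """Turn a list of per-iteration records into one entry per field."""
    aggregations = dict(aggregations or [])
    inverted = {}
    for f in fields(records[0]):
        if f.name in aggregations:
            # A configured aggregation sees all records at once.
            inverted[f.name] = aggregations[f.name](records)
        else:
            # Default: collect the values in iteration order.
            inverted[f.name] = [getattr(r, f.name) for r in records]
    return inverted


records = [Record(doubled=d * 2, squared=d**2) for d in [1, 2, 3]]
per_field = invert(records, aggregations=[("doubled", lambda rs: sum(r.doubled for r in rs))])
print(per_field)  # {'doubled': 12, 'squared': [1, 4, 9]}
```

The loop body only ever deals with one record, while the per-field view is derived once at the end.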

Let’s rewrite the above example using TypedIterator:

1. We define our result datatype as a dataclass.

from dataclasses import dataclass


@dataclass
class ResultType:
    result_1: int
    result_2: int
    result_3: int

2. We define the aggregations we want to apply to the results. If we don’t want to aggregate a result, we simply leave it out of the list. We explain aggregations in more detail below; for now, just accept this as given.

from tpcp.misc import TypedIteratorResultTuple


def sum_agg(results: list[TypedIteratorResultTuple[int, ResultType]]):
    return sum(r.result.result_1 for r in results)


aggregations = [
    ("result_1", sum_agg),
]

3. We create a new instance of TypedIterator with the result type and the aggregations. We use the “square bracket” typing syntax to bind the output datatype and the input datatype we are planning to iterate over. This way, our IDE is able to autocomplete the attributes of the result type.

from tpcp.misc import TypedIterator

iterator = TypedIterator[int, ResultType](ResultType, aggregations=aggregations)

Now we can iterate over our data. For each iteration, we get a result object that we can then fill with the results.

# Input values; they match the inputs listed in the raw results below.
data = [1, 2, 3, 4, 5]

for d, r in iterator.iterate(data):
    r.result_1 = d * 3
    r.result_2 = r.result_1 * 2
    r.result_3 = r.result_2 - 4

You can access the aggregated results using the results_ attribute of the iterator (iterator.results_).

ResultType(result_1=45, result_2=[6, 12, 18, 24, 30], result_3=[2, 8, 14, 20, 26])

Your IDE should be able to autocomplete the attributes of the result object, e.g. iterator.results_.result_1.

45

The raw results are available as a list of Result tuples. They allow us to access the results in the order they were created, and contain further metadata like the input data.

iterator.raw_results_
[TypedIteratorResultTuple(iteration_name='__main__', input=1, result=ResultType(result_1=3, result_2=6, result_3=2), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=2, result=ResultType(result_1=6, result_2=12, result_3=8), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=3, result=ResultType(result_1=9, result_2=18, result_3=14), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=4, result=ResultType(result_1=12, result_2=24, result_3=20), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=5, result=ResultType(result_1=15, result_2=30, result_3=26), iteration_context={})]
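
Because every raw result carries its input, results can be mapped back to the data they came from. The sketch below illustrates this with a hypothetical stand-in tuple (`ResultTuple`) that only mirrors the field names printed above; it is not the tpcp class, and the values are copied from the first two entries of the output.

```python
from dataclasses import dataclass
from typing import Any, NamedTuple


@dataclass
class ResultType:
    result_1: int
    result_2: int
    result_3: int


class ResultTuple(NamedTuple):
    # Hypothetical stand-in mirroring the fields shown in the printed output.
    iteration_name: str
    input: Any
    result: Any
    iteration_context: dict


raw_results = [
    ResultTuple("__main__", 1, ResultType(3, 6, 2), {}),
    ResultTuple("__main__", 2, ResultType(6, 12, 8), {}),
]

# Map each input back to one of its results.
result_1_by_input = {r.input: r.result.result_1 for r in raw_results}
print(result_1_by_input)  # {1: 3, 2: 6}
```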

While this version of the code required a couple more lines, it is much easier to understand and reason about. It clearly separates the configuration from the actual code and the core pipeline code is much cleaner.

A real-world example#

Below we apply this pattern to a pipeline that iterates over an actual dataset. The return types are a little bit more complex to show some more advanced features of aggregations.

For this example we apply the QRS detection algorithm to the ECG dataset demonstrated in some of the other examples. The QRS detection algorithm only has a single output. Hence, we use the “number of r-peaks” as a second result here to demonstrate the use case.

Again we start by defining the result dataclass.

import pandas as pd


@dataclass
class QRSResultType:
    """The result type of the QRS detection algorithm."""

    r_peak_positions: pd.Series
    n_r_peaks: int

Our input data is going to be a dataset object of the ECGExampleData type.

from pathlib import Path

from examples.datasets.datasets_final_ecg import ECGExampleData

try:
    HERE = Path(__file__).parent
except NameError:
    HERE = Path().resolve()
data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"

dataset = ECGExampleData(data_path)

For the aggregations, we want to concatenate the r-peak positions. The aggregation function receives all raw results as input, so it can access all inputs, all results, and all metadata. This means you can define any aggregation you want. In this case, we concatenate the r-peak positions into a single pandas Series indexed by the group label, and we turn n_r_peaks into a dictionary to make it easier to map the results back to the inputs.

Note that we can type these functions using the TypedIteratorResultTuple type. Like the iterator itself, this type is generic and allows you to specify the input and output types. So in our case, the input is ECGExampleData and the output is QRSResultType.

from typing_extensions import TypeAlias

from tpcp.misc import TypedIteratorResultTuple

result_tup: TypeAlias = TypedIteratorResultTuple[ECGExampleData, QRSResultType]


def concat_r_peak_positions(results: list[result_tup]):
    return pd.concat({r.input.group_label: r.result.r_peak_positions for r in results})


def aggregate_n_r_peaks(results: list[result_tup]):
    return {r.input.group_label: r.result.n_r_peaks for r in results}


aggregations = [
    ("r_peak_positions", concat_r_peak_positions),
    ("n_r_peaks", aggregate_n_r_peaks),
]

Now we can create the iterator. It takes the same type parameters as our result tuple. We then iterate over the dataset and apply the QRS detection algorithm.

from examples.algorithms.algorithms_qrs_detection_final import QRSDetector

qrs_iterator = TypedIterator[ECGExampleData, QRSResultType](QRSResultType, aggregations=aggregations)

for d, r in qrs_iterator.iterate(dataset):
    r.r_peak_positions = QRSDetector().detect(d.data["ecg"], sampling_rate_hz=d.sampling_rate_hz).r_peak_positions_
    r.n_r_peaks = len(r.r_peak_positions)

Finally, we can inspect the results stored on the iterator.

QRSResultType(r_peak_positions=group_1  100  0           77
              1          370
              2          663
              3          947
              4         1231
                       ...
group_3  200  1448    647546
              1449    648357
              1450    648629
              1451    649409
              1452    649928
Length: 17782, dtype: int64, n_r_peaks={ECGExampleDataGroupLabel(patient_group='group_1', participant='100'): 2270, ECGExampleDataGroupLabel(patient_group='group_2', participant='102'): 1710, ECGExampleDataGroupLabel(patient_group='group_3', participant='104'): 2066, ECGExampleDataGroupLabel(patient_group='group_1', participant='105'): 2567, ECGExampleDataGroupLabel(patient_group='group_2', participant='106'): 1704, ECGExampleDataGroupLabel(patient_group='group_3', participant='108'): 78, ECGExampleDataGroupLabel(patient_group='group_1', participant='114'): 30, ECGExampleDataGroupLabel(patient_group='group_2', participant='116'): 2392, ECGExampleDataGroupLabel(patient_group='group_3', participant='119'): 1988, ECGExampleDataGroupLabel(patient_group='group_1', participant='121'): 6, ECGExampleDataGroupLabel(patient_group='group_2', participant='123'): 1518, ECGExampleDataGroupLabel(patient_group='group_3', participant='200'): 1453})

Note that results_.r_peak_positions is now a single pandas Series and not a list of Series.

group_1  100  0           77
              1          370
              2          663
              3          947
              4         1231
                       ...
group_3  200  1448    647546
              1449    648357
              1450    648629
              1451    649409
              1452    649928
Length: 17782, dtype: int64

As expected, results_.n_r_peaks is a dictionary.

{ECGExampleDataGroupLabel(patient_group='group_1', participant='100'): 2270, ECGExampleDataGroupLabel(patient_group='group_2', participant='102'): 1710, ECGExampleDataGroupLabel(patient_group='group_3', participant='104'): 2066, ECGExampleDataGroupLabel(patient_group='group_1', participant='105'): 2567, ECGExampleDataGroupLabel(patient_group='group_2', participant='106'): 1704, ECGExampleDataGroupLabel(patient_group='group_3', participant='108'): 78, ECGExampleDataGroupLabel(patient_group='group_1', participant='114'): 30, ECGExampleDataGroupLabel(patient_group='group_2', participant='116'): 2392, ECGExampleDataGroupLabel(patient_group='group_3', participant='119'): 1988, ECGExampleDataGroupLabel(patient_group='group_1', participant='121'): 6, ECGExampleDataGroupLabel(patient_group='group_2', participant='123'): 1518, ECGExampleDataGroupLabel(patient_group='group_3', participant='200'): 1453}

The raw results are still available.

qrs_iterator.raw_results_
[TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_1         100, result=QRSResultType(r_peak_positions=0           77
1          370
2          663
3          947
4         1231
         ...
2265    648978
2266    649232
2267    649485
2268    649734
2269    649992
Length: 2270, dtype: int64, n_r_peaks=2270), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_2         102, result=QRSResultType(r_peak_positions=0          409
1          697
2          988
3         1304
4         1613
         ...
1705    648639
1706    648930
1707    649243
1708    649553
1709    649851
Length: 1710, dtype: int64, n_r_peaks=1710), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_3         104, result=QRSResultType(r_peak_positions=0           17
1          314
2          613
3          899
4         1186
         ...
2061    648729
2062    649021
2063    649298
2064    649578
2065    649874
Length: 2066, dtype: int64, n_r_peaks=2066), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_1         105, result=QRSResultType(r_peak_positions=0          197
1          459
2          708
3          964
4         1221
         ...
2562    648733
2563    648977
2564    649221
2565    649471
2566    649740
Length: 2567, dtype: int64, n_r_peaks=2567), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_2         106, result=QRSResultType(r_peak_positions=0          351
1          725
2         1086
3         1448
4         1830
         ...
1699    648969
1700    649161
1701    649335
1702    649792
1703    649990
Length: 1704, dtype: int64, n_r_peaks=1704), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_3         108, result=QRSResultType(r_peak_positions=0      10875
1     168524
2     169689
3     170426
4     170802
       ...
73    343872
74    359503
75    361856
76    472918
77    526420
Length: 78, dtype: int64, n_r_peaks=78), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_1         114, result=QRSResultType(r_peak_positions=0     281594
1     281953
2     282291
3     299048
4     300134
5     300486
6     303833
7     304565
8     305674
9     306034
10    307475
11    314265
12    354268
13    469019
14    477999
15    512726
16    513064
17    513384
18    627927
19    629093
20    629709
21    630859
22    631156
23    636224
24    636519
25    636821
26    637118
27    637399
28    637652
29    638046
dtype: int64, n_r_peaks=30), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_2         116, result=QRSResultType(r_peak_positions=0           16
1          284
2          562
3          838
4         1105
         ...
2387    648934
2388    649192
2389    649444
2390    649703
2391    649958
Length: 2392, dtype: int64, n_r_peaks=2392), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_3         119, result=QRSResultType(r_peak_positions=0          309
1          504
2          977
3         1315
4         1651
         ...
1983    648792
1984    649129
1985    649468
1986    649788
1987    649985
Length: 1988, dtype: int64, n_r_peaks=1988), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_1         121, result=QRSResultType(r_peak_positions=0      1569
1     88217
2     92814
3    168263
4    301711
5    581676
dtype: int64, n_r_peaks=6), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_2         123, result=QRSResultType(r_peak_positions=0           71
1          551
2         1022
3         1499
4         1926
         ...
1513    648248
1514    648627
1515    648999
1516    649343
1517    649690
Length: 1518, dtype: int64, n_r_peaks=1518), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=ECGExampleData [1 groups/rows]

     patient_group participant
   0       group_3         200, result=QRSResultType(r_peak_positions=0          488
1          965
2         1434
3         1883
4         2332
         ...
1448    647546
1449    648357
1450    648629
1451    649409
1452    649928
Length: 1453, dtype: int64, n_r_peaks=1453), iteration_context={})]

Custom Iterators#

When passing an iterable directly is not convenient, you can create a custom iterator class that reimplements iterate with custom logic. For example, you could provide a custom iterator that takes a data parameter and a sections parameter and then loops over the sections of the data.

For this we need to create a custom subclass inheriting from BaseTypedIterator.

from collections.abc import Iterator
from typing import Generic, TypeVar

from tpcp.misc import BaseTypedIterator

CustomTypeT = TypeVar("CustomTypeT")


class SectionIterator(BaseTypedIterator[pd.DataFrame, CustomTypeT], Generic[CustomTypeT]):
    def iterate(self, data: pd.DataFrame, sections: pd.DataFrame) -> Iterator[tuple[pd.DataFrame, CustomTypeT]]:
        # We turn the sections into a generator of dataframes
        data_iterable = (data.iloc[s.start : s.end] for s in sections.itertuples(index=False))
        # We use the `_iterate` method to do the heavy lifting
        yield from self._iterate(data_iterable)
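
The delegation pattern used here can be illustrated without tpcp. The following sketch uses a hypothetical stand-in base class (`MiniBaseIterator`, not tpcp's BaseTypedIterator) and assumes that the base class stores the same result object it yields, so that values set in the loop body end up in the stored results.

```python
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Optional


@dataclass
class SectionResult:
    # Hypothetical result object; starts empty and is filled by the loop body.
    n_samples: Optional[int] = None


class MiniBaseIterator:
    """Stand-in base class; NOT tpcp's BaseTypedIterator."""

    def __init__(self) -> None:
        self.raw_results = []

    def _iterate(self, iterable: Iterable) -> Iterator[tuple]:
        for item in iterable:
            result = SectionResult()
            # The same object is stored and yielded, so anything the caller
            # sets on it is visible in `raw_results` afterwards.
            self.raw_results.append((item, result))
            yield item, result


class MiniSectionIterator(MiniBaseIterator):
    def iterate(self, data: list, sections: list[tuple[int, int]]) -> Iterator[tuple]:
        # Only the slicing is custom; the bookkeeping is delegated.
        yield from self._iterate(data[start:end] for start, end in sections)


section_iterator = MiniSectionIterator()
for d, r in section_iterator.iterate(list(range(1, 11)), [(0, 5), (5, 10)]):
    r.n_samples = len(d)

print([(d, r.n_samples) for d, r in section_iterator.raw_results])
```

The subclass only decides what is iterated over; creating and storing result objects stays in one place.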

We create some dummy data and sections to test the iterator.

dummy_data = pd.DataFrame({"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})
dummy_sections = pd.DataFrame({"start": [0, 5], "end": [5, 10]})

Now we can use the iterator to iterate over the data. We skip any form of aggregation here, as it is not really relevant for this example, but it would work the same way as before.

@dataclass
class SimpleResultType:
    n_samples: int


custom_iterator = SectionIterator[SimpleResultType](SimpleResultType)

for d, r in custom_iterator.iterate(dummy_data, dummy_sections):
    print(d)
    r.n_samples = len(d)
   data
0     1
1     2
2     3
3     4
4     5
   data
5     6
6     7
7     8
8     9
9    10

We can see that the iterator iterated over the two sections of the data, and the raw results contain two instances of the result dataclass.

custom_iterator.raw_results_
[TypedIteratorResultTuple(iteration_name='__main__', input=   data
0     1
1     2
2     3
3     4
4     5, result=SimpleResultType(n_samples=5), iteration_context={}), TypedIteratorResultTuple(iteration_name='__main__', input=   data
5     6
6     7
7     8
8     9
9    10, result=SimpleResultType(n_samples=5), iteration_context={})]
SimpleResultType(n_samples=[5, 5])
[5, 5]

Advanced Use Cases#

For a really advanced use case, check out the GsIterator in mobgap. It makes use of sub-iterations, which allow you to iterate over and aggregate subregions of the data dynamically.

Additional Aggregators#

You can pass additional aggregators to the iterator whose names are not part of the result type. This allows you to perform additional aggregations. They work as before, but their results are not available on the result object; instead, they are exposed as a plain dictionary via the additional_results_ attribute. We show this below with the section iterator we defined above.

aggregations = [("sum_n_samples", lambda results: sum(r.result.n_samples for r in results))]

custom_iterator = SectionIterator[SimpleResultType](SimpleResultType, aggregations=aggregations)

for d, r in custom_iterator.iterate(dummy_data, dummy_sections):
    r.n_samples = len(d)

custom_iterator.results_
SimpleResultType(n_samples=[5, 5])
custom_iterator.additional_results_
{'sum_n_samples': 10}
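
One way to picture the distinction between the two kinds of aggregators is to split the list by whether a name matches a field of the result dataclass. The sketch below is a stdlib-only illustration under that assumption; `split_aggregations` and `SimpleResult` are hypothetical names and this is not how tpcp is necessarily implemented. The aggregators here also receive plain result objects rather than result tuples.

```python
from dataclasses import dataclass, fields


@dataclass
class SimpleResult:
    # Hypothetical result type with a single field.
    n_samples: int


def split_aggregations(result_type, aggregations):
    """Separate aggregators named after a field from all others."""
    field_names = {f.name for f in fields(result_type)}
    field_aggs = {name: func for name, func in aggregations if name in field_names}
    additional_aggs = {name: func for name, func in aggregations if name not in field_names}
    return field_aggs, additional_aggs


aggregations = [
    ("n_samples", lambda results: max(r.n_samples for r in results)),
    ("sum_n_samples", lambda results: sum(r.n_samples for r in results)),
]
field_aggs, additional_aggs = split_aggregations(SimpleResult, aggregations)

results = [SimpleResult(5), SimpleResult(5)]
additional_results = {name: func(results) for name, func in additional_aggs.items()}
print(sorted(field_aggs))  # ['n_samples']
print(additional_results)  # {'sum_n_samples': 10}
```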
