Grid Search for Optimal Algorithm Parameters#

If no better way exists to optimize a parameter of an algorithm or pipeline, an exhaustive grid search might be a good idea. tpcp provides a GridSearch that is algorithm agnostic (as long as you can wrap your algorithm into a pipeline).

As an example, we are going to grid search some parameters of the QRSDetector we implemented in Algorithms - A real world example: QRS-Detection.

To perform a GridSearch (or any other form of parameter optimization in tpcp), we first need a Dataset, a Pipeline, and a score function.

1. The Dataset#

Datasets wrap multiple recordings into an easy-to-use interface that can be passed around between the higher-level tpcp functions. Learn more about this here. If you are lucky, you do not need to create the dataset yourself, because someone has already created one for the data you want to use.

Here, we’re just going to reuse the ECGExampleData dataset we created in Custom Dataset - A real world example.

For our GridSearch, we need an instance of this dataset.

from pathlib import Path

from examples.datasets.datasets_final_ecg import ECGExampleData

try:
    HERE = Path(__file__).parent
except NameError:
    HERE = Path().resolve()
data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"
example_data = ECGExampleData(data_path)
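
To get a feeling for the data, we can inspect the dataset instance directly. A small sketch, assuming only the standard tpcp Dataset interface (len() and the index attribute):

print(len(example_data))   # number of datapoints (here: individual recordings)
print(example_data.index)  # a DataFrame describing all datapoints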

2. The Pipeline#

The pipeline simply defines which algorithms we want to run on our data and which parameters of the pipeline should remain adjustable (e.g. to optimize in the GridSearch).

The pipeline usually needs 3 things:

  1. It needs to be a subclass of Pipeline.

  2. It needs to have a run method that runs all the algorithmic steps and stores the results as class attributes. The run method should expect only a single data point (in our case a single recording of one sensor) as input.

  3. An __init__ that defines all parameters that should be adjustable. Note that the names in the function signature of the __init__ method must match the corresponding attribute names (e.g. max_cost -> self.max_cost). If you want to adjust multiple parameters that all belong to the same algorithm (and your algorithm is implemented as a subclass of Algorithm), it can be convenient to just pass the whole algorithm instance as a parameter.

Here we simply extract the data and sampling rate from the datapoint and then run the algorithm. We store the final results we are interested in on the pipeline object.

For the final GridSearch, we need an instance of the pipeline object.

import pandas as pd

from examples.algorithms.algorithms_qrs_detection_final import QRSDetector
from tpcp import Parameter, Pipeline, cf


class MyPipeline(Pipeline[ECGExampleData]):
    algorithm: Parameter[QRSDetector]

    r_peak_positions_: pd.Series

    def __init__(self, algorithm: QRSDetector = cf(QRSDetector())):
        self.algorithm = algorithm

    def run(self, datapoint: ECGExampleData):
        # Note: We need to clone the algorithm instance, to make sure we don't leak any data between runs.
        algo = self.algorithm.clone()
        algo.detect(datapoint.data["ecg"], datapoint.sampling_rate_hz)

        self.r_peak_positions_ = algo.r_peak_positions_
        return self


pipe = MyPipeline()
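
Before searching over parameters, we can sanity-check the pipeline on a single datapoint. A sketch; we clone pipe first so the shared instance stays untouched:

# Run the (unoptimized) pipeline on the first recording of the dataset.
checked_pipe = pipe.clone().safe_run(example_data[0])
print(checked_pipe.r_peak_positions_.head())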

3. The Scorer#

In the context of a GridSearch, we want to calculate the performance of our algorithm and rank the different parameter candidates accordingly. This is what our score function is for. It gets a pipeline object (without results!) and a data point (i.e. a single recording) as input and should return some sort of performance metric. A higher value is always considered better. If you want to calculate multiple performance measures, you can also return a dictionary of such values. In any case, the performance for a specific parameter combination in the GridSearch is calculated as the mean over all datapoints. (Note: if you want to change this, you can create a custom Aggregator.)

A typical score function will first call safe_run (which calls run internally) on the pipeline and then compare the output with some reference. This reference should be supplied as part of the dataset.

Instead of using a function as a scorer (as shown here), you can also implement a method called score on your pipeline. Then just pass None (which is the default) for the scoring parameter of the GridSearch (and other optimizers). However, a function is usually more flexible. A sketch of this method-based variant follows after the score function below.

In this case, we compare the identified R-peaks with the reference and determine which R-peaks were correctly found within a certain margin around the reference points. Based on these matches, we calculate the precision, the recall, and the f1-score using some helper functions.

from examples.algorithms.algorithms_qrs_detection_final import match_events_with_reference, precision_recall_f1_score


def score(pipeline: MyPipeline, datapoint: ECGExampleData):
    # We use the `safe_run` wrapper instead of just run. This is always a good idea.
    # We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
    # will clone it again.
    pipeline = pipeline.safe_run(datapoint)
    tolerance_s = 0.02  # We just use 20 ms for this example
    matches = match_events_with_reference(
        pipeline.r_peak_positions_.to_numpy(),
        datapoint.r_peak_positions_.to_numpy(),
        tolerance=tolerance_s * datapoint.sampling_rate_hz,
    )
    precision, recall, f1_score = precision_recall_f1_score(matches)
    return {"precision": precision, "recall": recall, "f1_score": f1_score}

The Parameters#

The last step before running the GridSearch is to select the parameter values we want to test. For this, we can directly use sklearn’s ParameterGrid.

In this example, we will just test three values for the high_pass_filter_cutoff_hz. As this is a nested parameter, we use the __ syntax to set it.

from sklearn.model_selection import ParameterGrid

parameters = ParameterGrid({"algorithm__high_pass_filter_cutoff_hz": [0.25, 0.5, 1]})
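
Grids over multiple (nested) parameters form the full cross product of all listed values. A sketch with a second QRSDetector parameter (the values are purely illustrative):

larger_grid = ParameterGrid(
    {
        "algorithm__high_pass_filter_cutoff_hz": [0.25, 0.5, 1],
        "algorithm__min_r_peak_height_over_baseline": [0.5, 1.0],
    }
)
print(len(larger_grid))  # 3 * 2 = 6 parameter candidates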

Running the GridSearch#

Now we have all the pieces to run the GridSearch. After initializing it, we can call optimize to run the search on our dataset.

Note

If the score function returns a dictionary of scores, return_optimized must be set to the name of the score that should be used to decide on the best parameter set.

from tpcp.optimize import GridSearch

gs = GridSearch(pipe, parameters, scoring=score, return_optimized="f1_score")
gs = gs.optimize(example_data)
Parameter Combinations: 100%|██████████| 3/3 [00:02<00:00,  1.10it/s]
Datapoints: 100%|██████████| 12/12 [00:00<00:00, 14.28it/s]

The main results are stored in gs_results_. They contain the mean performance per parameter combination, the rank of each parameter combination, and the performance on each individual data point (in our case, a single recording of one participant).
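
As gs_results_ is a plain dictionary, a convenient way to inspect it (a sketch, assuming only this documented attribute) is to load it into a pandas DataFrame:

results = gs.gs_results_
results_df = pd.DataFrame(results)
results_df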

   agg__precision  rank__agg__precision  agg__recall  rank__agg__recall  agg__f1_score  rank__agg__f1_score  debug__score_time  param__algorithm__high_pass_filter_cutoff_hz
0        0.982041                     3     0.678018                  1       0.719864                    1           0.903649                                          0.25
1        0.987444                     2     0.677212                  2       0.716901                    2           0.901670                                           0.5
2        0.992936                     1     0.673776                  3       0.708973                    3           0.913739                                             1

[3 rows x 13 columns; per-datapoint columns (single__precision, single__recall, single__f1_score), data_labels, and params not shown]


Further, the optimized_pipeline_ attribute holds an instance of the pipeline initialized with the best parameter combination.

print("Best Para Combi:", gs.best_params_)
print("Paras of optimized Pipeline:", gs.optimized_pipeline_.get_params())
Best Para Combi: {'algorithm__high_pass_filter_cutoff_hz': 0.25}
Paras of optimized Pipeline: {'algorithm__high_pass_filter_cutoff_hz': 0.25, 'algorithm__max_heart_rate_bpm': 200.0, 'algorithm__min_r_peak_height_over_baseline': 1.0, 'algorithm': QRSDetector(high_pass_filter_cutoff_hz=0.25, max_heart_rate_bpm=200.0, min_r_peak_height_over_baseline=1.0)}

To run the optimized pipeline, we can directly use the run/safe_run method on the GridSearch object. This makes it possible to use the GridSearch as a replacement for your pipeline object with minimal code changes.

If you try to call run/safe_run (or score, for that matter) before the optimization, an error is raised.
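
A quick sketch of that behavior (the exact exception type depends on the tpcp version):

fresh_gs = GridSearch(pipe, parameters, scoring=score, return_optimized="f1_score")
try:
    fresh_gs.safe_run(example_data[0])
except Exception as e:
    print(f"{type(e).__name__}: {e}")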

r_peaks = gs.safe_run(example_data[0]).r_peak_positions_
r_peaks
0           77
1          370
2          663
3          947
4         1231
         ...
2264    648733
2265    648978
2266    649232
2267    649485
2268    649991
Length: 2269, dtype: int64

Total running time of the script: (0 minutes 3.957 seconds)

Estimated memory usage: 59 MB
