GridSearchCV

When optimizing the parameters of algorithms that have trainable components, the parameter search must be performed on a validation set that is separate from the test set used for the final evaluation. Even better is to use cross-validation for this step. In tpcp, this can be done using GridSearchCV.

This example explains how to use this method. To learn more about the concept, review the evaluation guide and the sklearn guide on tuning hyperparameters.
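The separation between the data used for the parameter search and a held-out test set can be sketched in plain Python. This is only an illustration under stated assumptions: `holdout_split` is a hypothetical helper (not part of tpcp), the record labels are placeholders, and the 25 % test share is an arbitrary choice. A search such as GridSearchCV would then only ever receive the search set.

```python
import random


def holdout_split(datapoints, test_fraction, seed):
    """Hypothetical helper: split datapoints into a search set and a held-out test set."""
    rng = random.Random(seed)  # local RNG, so the split is repeatable and independent of global state
    shuffled = list(datapoints)
    rng.shuffle(shuffled)
    n_test = max(1, round(len(shuffled) * test_fraction))
    # Only the first part is handed to the parameter search (e.g. GridSearchCV);
    # the second part is reserved for the final evaluation.
    return shuffled[n_test:], shuffled[:n_test]


records = [f"record_{i}" for i in range(12)]  # placeholder datapoint labels
search_set, test_set = holdout_split(records, test_fraction=0.25, seed=1)
print(len(search_set), len(test_set))  # 9 3
```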

import random

import pandas as pd
from typing_extensions import Self

random.seed(1)  # We set the random seed for repeatable results

Dataset

As always, we need a dataset, a pipeline, and a scoring method for a parameter search. Here, we’re just going to reuse the ECGExampleData dataset we created in Custom Dataset - A real world example.

from pathlib import Path

from examples.datasets.datasets_final_ecg import ECGExampleData

try:
    HERE = Path(__file__).parent
except NameError:
    HERE = Path().resolve()
data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"
example_data = ECGExampleData(data_path)

from typing import Any

The Pipeline

When using GridSearchCV, our pipeline must be “optimizable”. Otherwise, we would have no need for the CV part and could simply use a normal GridSearch. Here, we are going to create an optimizable pipeline that wraps the optimizable version of the QRS detector we developed in Algorithms - A real world example: QRS-Detection.

For more information about the pipeline below, check our examples on Optimizable Pipelines.

from examples.algorithms.algorithms_qrs_detection_final import OptimizableQrsDetector
from tpcp import OptimizableParameter, OptimizablePipeline, Parameter, cf


class MyPipeline(OptimizablePipeline[ECGExampleData]):
    algorithm: Parameter[OptimizableQrsDetector]
    algorithm__min_r_peak_height_over_baseline: OptimizableParameter[float]

    r_peak_positions_: pd.Series

    def __init__(self, algorithm: OptimizableQrsDetector = cf(OptimizableQrsDetector())):
        self.algorithm = algorithm

    def self_optimize(self, dataset: ECGExampleData, **kwargs: Any) -> Self:
        ecg_data = [d.data["ecg"] for d in dataset]
        r_peaks = [d.r_peak_positions_["r_peak_position"] for d in dataset]
        # Note: We need to clone the algorithm instance, to make sure we don't leak any data between runs.
        algo = self.algorithm.clone()
        self.algorithm = algo.self_optimize(ecg_data, r_peaks, dataset.sampling_rate_hz)
        return self

    def run(self, datapoint: ECGExampleData) -> Self:
        # Note: We need to clone the algorithm instance, to make sure we don't leak any data between runs.
        algo = self.algorithm.clone()
        algo.detect(datapoint.data["ecg"], datapoint.sampling_rate_hz)

        self.r_peak_positions_ = algo.r_peak_positions_
        return self


pipe = MyPipeline()
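The clone calls in MyPipeline guard against state leaking between runs. The effect can be shown with a toy class. `ThresholdDetector` is hypothetical, and `copy.deepcopy` is used here only as a stand-in for tpcp's `clone`, assuming both give an independent instance for the purpose of this illustration.

```python
import copy


class ThresholdDetector:
    """Hypothetical toy algorithm with a single trainable parameter."""

    def __init__(self, threshold: float = 0.1):
        self.threshold = threshold

    def self_optimize(self, peak_heights: list[float]) -> "ThresholdDetector":
        # "Training": set the threshold to half the smallest reference peak height.
        self.threshold = min(peak_heights) * 0.5
        return self


base = ThresholdDetector()

# Without a copy, optimizing on one fold silently changes the shared instance.
shared = base
shared.self_optimize([1.0, 2.0])
print(base.threshold)  # 0.5 (the default of 0.1 is gone)

# With a copy per run (deepcopy stands in for tpcp's `clone` here),
# every fold starts from the same untouched defaults.
base = ThresholdDetector()
fold_a = copy.deepcopy(base).self_optimize([1.0, 2.0])
fold_b = copy.deepcopy(base).self_optimize([3.0, 4.0])
print(base.threshold, fold_a.threshold, fold_b.threshold)  # 0.1 0.5 1.5
```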

The Scorer

The scorer is identical to the scoring function used in the other examples. The F1-score is still the most important metric for our comparison.

from examples.algorithms.algorithms_qrs_detection_final import match_events_with_reference, precision_recall_f1_score


def score(pipeline: MyPipeline, datapoint: ECGExampleData) -> dict[str, float]:
    # We use the `safe_run` wrapper instead of just run. This is always a good idea.
    # We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
    # will clone it again.
    pipeline = pipeline.safe_run(datapoint)
    tolerance_s = 0.02  # We just use 20 ms for this example
    matches = match_events_with_reference(
        pipeline.r_peak_positions_.to_numpy(),
        datapoint.r_peak_positions_.to_numpy(),
        tolerance=tolerance_s * datapoint.sampling_rate_hz,
    )
    precision, recall, f1_score = precision_recall_f1_score(matches)
    return {"precision": precision, "recall": recall, "f1_score": f1_score}
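To make the scoring step more concrete, here is a simplified, self-contained version of tolerance-based event matching followed by precision, recall, and F1-score. It is a sketch under assumptions: `match_with_tolerance` and `precision_recall_f1` are hypothetical helpers that use greedy one-to-one matching of sorted positions, not the implementation of match_events_with_reference or precision_recall_f1_score from the examples. As in score, the tolerance is expressed in samples (seconds times sampling rate); the 360 Hz rate below is an assumed value.

```python
def match_with_tolerance(detected, reference, tolerance):
    """Hypothetical greedy one-to-one matching of event positions (in samples)."""
    detected, reference = sorted(detected), sorted(reference)
    tp, i, j = 0, 0, 0
    while i < len(detected) and j < len(reference):
        diff = detected[i] - reference[j]
        if abs(diff) <= tolerance:
            tp += 1  # matched pair: consume both events
            i += 1
            j += 1
        elif diff < 0:
            i += 1  # detection too early for any remaining reference: false positive
        else:
            j += 1  # reference event without a detection: false negative
    fp = len(detected) - tp
    fn = len(reference) - tp
    return tp, fp, fn


def precision_recall_f1(tp, fp, fn):
    """Standard definitions; empty denominators yield 0.0."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# 20 ms tolerance at an assumed sampling rate of 360 Hz, i.e. 7.2 samples.
tp, fp, fn = match_with_tolerance([100, 205, 400], [102, 200, 300], tolerance=0.02 * 360)
print(tp, fp, fn)  # 2 1 1
print(precision_recall_f1(tp, fp, fn))
```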

Data Splitting

As with a normal cross-validation, we need to decide on the number of folds and the type of split. tpcp supports all cross-validation iterators provided by sklearn.

To keep the runtime low for this example, we are going to use a 2-fold CV.

from sklearn.model_selection import KFold

cv = KFold(n_splits=2)
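Without shuffling, sklearn's KFold splits the datapoints into contiguous blocks, with the first `n_samples % n_splits` folds receiving one extra datapoint. The plain-Python sketch below re-implements that documented behaviour for illustration only; `kfold_indices` is a hypothetical helper, and the example itself passes the KFold object above to GridSearchCV. The 12 datapoints are an arbitrary choice.

```python
def kfold_indices(n_samples: int, n_splits: int):
    """Hypothetical helper: contiguous, unshuffled k-fold splits as (train, test) index lists."""
    fold_sizes = [n_samples // n_splits] * n_splits
    for k in range(n_samples % n_splits):
        fold_sizes[k] += 1  # the first folds absorb the remainder
    start = 0
    for size in fold_sizes:
        test = list(range(start, start + size))
        train = [i for i in range(n_samples) if i < start or i >= start + size]
        yield train, test
        start += size


for train, test in kfold_indices(n_samples=12, n_splits=2):
    print("train:", train, "test:", test)
```

Because the split is not shuffled, the order of the dataset determines the folds. If that is not desired, KFold also accepts `shuffle=True` together with a `random_state`.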

The Parameters

The pipeline above exposes a couple of (nested) parameters. min_r_peak_height_over_baseline is the parameter we want to optimize. All other parameters are effectively hyperparameters, as they change the outcome of the optimization. We could differentiate further and say that only r_peak_match_tolerance_s is a true hyperparameter, as it only affects the outcome of the optimization, while the run method is independent of it. max_heart_rate_bpm and high_pass_filter_cutoff_hz affect both the optimization and run.

We could run the grid search over any combination of parameters. However, to keep things simple, we will only test a couple of values for high_pass_filter_cutoff_hz.

from sklearn.model_selection import ParameterGrid

parameters = ParameterGrid({"algorithm__high_pass_filter_cutoff_hz": [0.25, 0.5, 1]})
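ParameterGrid expands a dictionary of value lists into all combinations of values. The stdlib sketch below shows the same idea with `itertools.product`; `expand_grid` is a hypothetical helper, and the values for max_heart_rate_bpm are arbitrary. It also shows why the number of runs grows quickly: every combination is optimized and scored once per fold, so 3 values and 2 folds already give 6 split-parameter combinations.

```python
from itertools import product


def expand_grid(grid: dict[str, list]) -> list[dict]:
    """Hypothetical helper: Cartesian product of all parameter values (similar in spirit to ParameterGrid)."""
    keys = sorted(grid)  # fixed key order for reproducible iteration
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


single = expand_grid({"algorithm__high_pass_filter_cutoff_hz": [0.25, 0.5, 1]})
print(single)

# A second parameter (arbitrary example values) multiplies the number of combinations.
combined = expand_grid(
    {
        "algorithm__high_pass_filter_cutoff_hz": [0.25, 0.5, 1],
        "algorithm__max_heart_rate_bpm": [180.0, 200.0],
    }
)
print(len(combined))  # 6
```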

GridSearchCV

Setting up the GridSearchCV object is similar to setting up a normal GridSearch; we just need to add the additional cv parameter. With return_optimized="f1_score", the F1-score is used to select the best parameter combination. Then we can run the search using the optimize method.

from tpcp.optimize import GridSearchCV

gs = GridSearchCV(pipeline=MyPipeline(), parameter_grid=parameters, scoring=score, cv=cv, return_optimized="f1_score")
gs = gs.optimize(example_data)
Split-Para Combos: 100%|██████████| 6/6 [00:04<00:00,  1.32it/s]
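Conceptually, the search above works as follows: for every parameter combination and every fold, a fresh copy of the pipeline is optimized on the training part and scored on the test part; the combination with the best mean score is selected and, because return_optimized is set, optimized once more on all data. The sketch below mirrors only these steps, using a hypothetical toy pipeline (`ShrunkMeanPipeline`) and a hypothetical `grid_search_cv` function. It is not tpcp's implementation, which among other things scores each test datapoint individually and fills the full cv_results_ table.

```python
import copy
from statistics import mean


class ShrunkMeanPipeline:
    """Hypothetical toy pipeline: `shrink` is a hyperparameter, `level_` is learned in self_optimize."""

    def __init__(self, shrink: float = 1.0):
        self.shrink = shrink

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def self_optimize(self, train_values):
        self.level_ = self.shrink * mean(train_values)
        return self

    def score(self, test_values):
        # Higher is better: negative mean absolute error of the constant prediction.
        return -mean(abs(v - self.level_) for v in test_values)


def grid_search_cv(pipeline, grid, values, splits):
    results = []
    for params in grid:
        fold_scores = []
        for train_idx, test_idx in splits:
            # Fresh copy per combination and fold, so nothing learned leaks between runs.
            pipe = copy.deepcopy(pipeline).set_params(**params)
            pipe.self_optimize([values[i] for i in train_idx])
            fold_scores.append(pipe.score([values[i] for i in test_idx]))
        results.append({"params": params, "mean_score": mean(fold_scores)})
    best = max(results, key=lambda r: r["mean_score"])
    # Final optimization with the best parameters on all available data.
    optimized = copy.deepcopy(pipeline).set_params(**best["params"]).self_optimize(values)
    return results, best["params"], optimized


values = [4.0, 6.0, 5.0, 5.0]
splits = [([2, 3], [0, 1]), ([0, 1], [2, 3])]  # 2 contiguous folds
grid = [{"shrink": s} for s in (0.0, 0.5, 1.0)]
results, best_params, optimized = grid_search_cv(ShrunkMeanPipeline(), grid, values, splits)
print(best_params, optimized.level_)  # {'shrink': 1.0} 5.0
```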

Results

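The output is also comparable to the output of GridSearch. The main results are stored in the cv_results_ attribute. But instead of a single performance value per parameter combination, we get one value per fold, plus the mean and std over all folds. Converting the results to a DataFrame makes them easier to inspect:

results_df = pd.DataFrame(gs.cv_results_)
results_df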

mean_optimize_time std_optimize_time mean_score_time std_score_time split0_test_data_labels split1_test_data_labels split0_train_data_labels split1_train_data_labels param_algorithm__high_pass_filter_cutoff_hz params split0_test_precision split1_test_precision mean_test_precision std_test_precision rank_test_precision split0_test_recall split1_test_recall mean_test_recall std_test_recall rank_test_recall split0_test_f1_score split1_test_f1_score mean_test_f1_score std_test_f1_score rank_test_f1_score split0_test_single_precision split1_test_single_precision split0_test_single_recall split1_test_single_recall split0_test_single_f1_score split1_test_single_f1_score
0 0.335774 0.024256 0.427677 0.015528 [(group_1, 100), (group_2, 102), (group_3, 104... [(group_1, 114), (group_2, 116), (group_3, 119... [(group_1, 114), (group_2, 116), (group_3, 119... [(group_1, 100), (group_2, 102), (group_3, 104... 0.25 {'algorithm__high_pass_filter_cutoff_hz': 0.25} 0.939253 0.936646 0.937949 0.001304 3 0.886065 0.800842 0.843453 0.042612 1 0.903974 0.824031 0.864003 0.039972 2 [0.9995600527936648, 0.9724391364262747, 0.961... [0.8974358974358975, 0.9954147561483951, 0.998... [0.9995600527936648, 0.9679926840420667, 0.967... [0.1490154337413518, 0.9900497512437811, 0.999... [0.9995600527936648, 0.9702108157653528, 0.964... [0.25559105431309903, 0.9927250051964249, 0.99...
1 0.308887 0.000973 0.431293 0.004875 [(group_1, 100), (group_2, 102), (group_3, 104... [(group_1, 114), (group_2, 116), (group_3, 119... [(group_1, 114), (group_2, 116), (group_3, 119... [(group_1, 100), (group_2, 102), (group_3, 104... 0.5 {'algorithm__high_pass_filter_cutoff_hz': 0.5} 0.951106 0.946568 0.948837 0.002269 2 0.881955 0.795663 0.838809 0.043146 3 0.904481 0.818777 0.861629 0.042852 3 [0.9995600527936648, 0.9722735674676525, 0.962... [0.9486166007905138, 0.9974947807933194, 0.998... [0.9995600527936648, 0.9620484682213077, 0.967... [0.12772751463544438, 0.9904643449419569, 0.99... [0.9995600527936648, 0.9671339921857045, 0.964... [0.225140712945591, 0.9939671312669024, 0.9992...
2 0.309644 0.001590 0.427927 0.002174 [(group_1, 100), (group_2, 102), (group_3, 104... [(group_1, 114), (group_2, 116), (group_3, 119... [(group_1, 114), (group_2, 116), (group_3, 119... [(group_1, 100), (group_2, 102), (group_3, 104... 1 {'algorithm__high_pass_filter_cutoff_hz': 1} 0.959503 0.946947 0.953225 0.006278 1 0.882437 0.796150 0.839294 0.043143 2 0.907905 0.823164 0.865534 0.042370 1 [0.9995600527936648, 0.9723119520073835, 0.962... [0.9228070175438596, 0.9974947807933194, 0.998... [0.9995600527936648, 0.9634202103337905, 0.968... [0.13996806812134113, 0.9904643449419569, 0.99... [0.9995600527936648, 0.9678456591639871, 0.965... [0.24306839186691312, 0.9939671312669024, 0.99...
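The aggregated columns can be reproduced from the per-fold values. Using the precision values copied from the table above, the mean and the population standard deviation (ddof=0) match mean_test_precision and std_test_precision to the displayed precision; with two folds, this standard deviation is simply half the distance between the two fold values. That the table uses ddof=0 is inferred from these numbers, not taken from the tpcp documentation.

```python
from statistics import mean, pstdev

# Fold-level test precision (split0, split1) per high_pass_filter_cutoff_hz, copied from the table above.
split_precision = {
    0.25: (0.939253, 0.936646),
    0.5: (0.951106, 0.946568),
    1: (0.959503, 0.946947),
}

for cutoff, folds in split_precision.items():
    # pstdev is the population standard deviation (ddof=0).
    print(cutoff, round(mean(folds), 6), round(pstdev(folds), 6))
```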


The mean test score is the primary value used to select the best parameter combination (if return_optimized is True or, as here, the name of a score). All other performance values are just there to provide further insight.

results_df[["mean_test_precision", "mean_test_recall", "mean_test_f1_score"]]
mean_test_precision mean_test_recall mean_test_f1_score
0 0.937949 0.843453 0.864003
1 0.948837 0.838809 0.861629
2 0.953225 0.839294 0.865534
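The rank columns and the best parameter combination follow from sorting the mean scores. The plain-Python sketch below, with values copied from mean_test_f1_score above, reproduces rank_test_f1_score and the selected high_pass_filter_cutoff_hz of 1. Ties are deliberately ignored in this sketch.

```python
# mean_test_f1_score per high_pass_filter_cutoff_hz, copied from the table above.
mean_f1 = {0.25: 0.864003, 0.5: 0.861629, 1: 0.865534}

# Rank 1 = highest mean score, mirroring the rank_test_f1_score column (ties not handled).
ordered = sorted(mean_f1, key=mean_f1.get, reverse=True)
ranks = {cutoff: position + 1 for position, cutoff in enumerate(ordered)}
best_cutoff = ordered[0]

print(ranks)  # {1: 1, 0.25: 2, 0.5: 3}
print({"algorithm__high_pass_filter_cutoff_hz": best_cutoff})
```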


For even more insight, you can inspect the scores per datapoint:

results_df.filter(like="test_single")
split0_test_single_precision split1_test_single_precision split0_test_single_recall split1_test_single_recall split0_test_single_f1_score split1_test_single_f1_score
0 [0.9995600527936648, 0.9724391364262747, 0.961... [0.8974358974358975, 0.9954147561483951, 0.998... [0.9995600527936648, 0.9679926840420667, 0.967... [0.1490154337413518, 0.9900497512437811, 0.999... [0.9995600527936648, 0.9702108157653528, 0.964... [0.25559105431309903, 0.9927250051964249, 0.99...
1 [0.9995600527936648, 0.9722735674676525, 0.962... [0.9486166007905138, 0.9974947807933194, 0.998... [0.9995600527936648, 0.9620484682213077, 0.967... [0.12772751463544438, 0.9904643449419569, 0.99... [0.9995600527936648, 0.9671339921857045, 0.964... [0.225140712945591, 0.9939671312669024, 0.9992...
2 [0.9995600527936648, 0.9723119520073835, 0.962... [0.9228070175438596, 0.9974947807933194, 0.998... [0.9995600527936648, 0.9634202103337905, 0.968... [0.13996806812134113, 0.9904643449419569, 0.99... [0.9995600527936648, 0.9678456591639871, 0.965... [0.24306839186691312, 0.9939671312669024, 0.99...


If return_optimized was set to True (or the name of a score), a final optimization is performed using the best set of parameters and all the available data. The resulting pipeline is stored in optimized_pipeline_.

print("Best Para Combi:", gs.best_params_)
print("Paras of optimized Pipeline:", gs.optimized_pipeline_.get_params())
Best Para Combi: {'algorithm__high_pass_filter_cutoff_hz': 1}
Paras of optimized Pipeline: {'algorithm__high_pass_filter_cutoff_hz': 1, 'algorithm__max_heart_rate_bpm': 200.0, 'algorithm__min_r_peak_height_over_baseline': 0.6322168257130579, 'algorithm__r_peak_match_tolerance_s': 0.01, 'algorithm': OptimizableQrsDetector(high_pass_filter_cutoff_hz=1, max_heart_rate_bpm=200.0, min_r_peak_height_over_baseline=0.6322168257130579, r_peak_match_tolerance_s=0.01)}

To run the optimized pipeline, we can directly use the run/safe_run method on the GridSearchCV object. This makes it possible to use the GridSearchCV object as a replacement for your pipeline object with minimal code changes.

If you try to call run/safe_run (or score, for that matter) before the optimization, an error is raised.

r_peaks = gs.safe_run(example_data[0]).r_peak_positions_
r_peaks
0           77
1          370
2          663
3          947
4         1231
         ...
2268    648978
2269    649232
2270    649485
2271    649734
2272    649992
Length: 2273, dtype: int64
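The delegation described above can be illustrated with a small wrapper. `OptimizedWrapper` and `DoublingPipeline` are hypothetical classes; the exception type and message are placeholders and not necessarily what tpcp raises. The wrapper only forwards run to the pipeline stored by optimize, and it works on a copy so that the stored pipeline stays unchanged.

```python
import copy


class DoublingPipeline:
    """Hypothetical toy pipeline: learns a factor in self_optimize and applies it in run."""

    def self_optimize(self, data):
        self.factor_ = 2
        return self

    def run(self, datapoint):
        self.result_ = datapoint * self.factor_
        return self


class OptimizedWrapper:
    """Hypothetical wrapper illustrating the delegation pattern, not tpcp's GridSearchCV."""

    def __init__(self, pipeline_factory):
        self.pipeline_factory = pipeline_factory

    def optimize(self, data):
        self.optimized_pipeline_ = self.pipeline_factory().self_optimize(data)
        return self

    def run(self, datapoint):
        if not hasattr(self, "optimized_pipeline_"):
            raise RuntimeError("Call `optimize` before `run`.")
        # Run a copy, so the stored optimized pipeline is not modified.
        return copy.deepcopy(self.optimized_pipeline_).run(datapoint)


wrapper = OptimizedWrapper(DoublingPipeline)
try:
    wrapper.run(3)
except RuntimeError as err:
    print("Error:", err)  # Error: Call `optimize` before `run`.

print(wrapper.optimize([1, 2]).run(3).result_)  # 6
```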

Total running time of the script: (0 minutes 7.435 seconds)

Estimated memory usage: 38 MB
