.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/validation/_04_advanced_cross_validation.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_validation__04_advanced_cross_validation.py:

Advanced cross-validation
-------------------------

In many real-world datasets, a plain k-fold cross-validation might not be ideal, as it assumes that all data points
are independent of each other.
This is often not the case, as a dataset might contain multiple data points from the same participant.
Furthermore, we might have multiple "stratification" variables that we want to keep balanced across the folds,
for example, different clinical conditions or different measurement devices.

These two concepts of "grouping" and "stratification" can be hard to understand, and certain (even though common)
cases are not supported by the standard sklearn cross-validation splitters without "abusing" the API.
For this reason, tpcp provides dedicated support to tackle these cases with a little more confidence.

.. GENERATED FROM PYTHON SOURCE LINES 16-20

Let's start by re-creating the simple setup from the normal cross-validation example.

Dataset
+++++++

.. GENERATED FROM PYTHON SOURCE LINES 20-31

.. code-block:: default

    from pathlib import Path

    from examples.datasets.datasets_final_ecg import ECGExampleData

    try:
        HERE = Path(__file__).parent
    except NameError:
        HERE = Path().resolve()
    data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"
    example_data = ECGExampleData(data_path)

.. GENERATED FROM PYTHON SOURCE LINES 32-34

Pipeline
++++++++

.. GENERATED FROM PYTHON SOURCE LINES 34-57

.. code-block:: default

    import pandas as pd
    from tpcp import Parameter, Pipeline, cf

    from examples.algorithms.algorithms_qrs_detection_final import QRSDetector


    class MyPipeline(Pipeline):
        algorithm: Parameter[QRSDetector]

        r_peak_positions_: pd.Series

        def __init__(self, algorithm: QRSDetector = cf(QRSDetector())):
            self.algorithm = algorithm

        def run(self, datapoint: ECGExampleData):
            # Note: We need to clone the algorithm instance, to make sure we don't leak any data between runs.
            algo = self.algorithm.clone()
            algo.detect(datapoint.data, datapoint.sampling_rate_hz)

            self.r_peak_positions_ = algo.r_peak_positions_
            return self

.. GENERATED FROM PYTHON SOURCE LINES 58-60

The Scorer
++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 60-81

.. code-block:: default

    from examples.algorithms.algorithms_qrs_detection_final import (
        match_events_with_reference,
        precision_recall_f1_score,
    )


    def score(pipeline: MyPipeline, datapoint: ECGExampleData):
        # We use the `safe_run` wrapper instead of just run. This is always a good idea.
        # We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
        # will clone it again.
        pipeline = pipeline.safe_run(datapoint)
        tolerance_s = 0.02  # We just use 20 ms for this example
        matches = match_events_with_reference(
            pipeline.r_peak_positions_.to_numpy(),
            datapoint.r_peak_positions_.to_numpy(),
            tolerance=tolerance_s * datapoint.sampling_rate_hz,
        )
        precision, recall, f1_score = precision_recall_f1_score(matches)
        return {"precision": precision, "recall": recall, "f1_score": f1_score}

.. GENERATED FROM PYTHON SOURCE LINES 82-85

Stratification
++++++++++++++
With this setup done, we can have a closer look at the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 85-87

.. code-block:: default

    example_data

.. raw:: html

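    <div>
    <h4>ECGExampleData [12 groups/rows]</h4>
    <table>
      <thead>
        <tr><th></th><th>patient_group</th><th>participant</th></tr>
      </thead>
      <tbody>
        <tr><th>0</th><td>group_1</td><td>100</td></tr>
        <tr><th>1</th><td>group_2</td><td>102</td></tr>
        <tr><th>2</th><td>group_3</td><td>104</td></tr>
        <tr><th>3</th><td>group_1</td><td>105</td></tr>
        <tr><th>4</th><td>group_2</td><td>106</td></tr>
        <tr><th>5</th><td>group_3</td><td>108</td></tr>
        <tr><th>6</th><td>group_1</td><td>114</td></tr>
        <tr><th>7</th><td>group_2</td><td>116</td></tr>
        <tr><th>8</th><td>group_3</td><td>119</td></tr>
        <tr><th>9</th><td>group_1</td><td>121</td></tr>
        <tr><th>10</th><td>group_2</td><td>123</td></tr>
        <tr><th>11</th><td>group_3</td><td>200</td></tr>
      </tbody>
    </table>
    </div>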


.. GENERATED FROM PYTHON SOURCE LINES 88-94

The index has two columns, one indicating the participant group and one indicating the participant id.
In this simple example, all groups appear the same number of times, and the index is ordered in a way that each fold
will likely get a balanced number of participants from each group.

To show the impact of grouping and stratification, we take a subset of the data that removes two participants
(114 and 121) from "group_1" to create an imbalance.

.. GENERATED FROM PYTHON SOURCE LINES 94-98

.. code-block:: default

    data_imbalanced = example_data.get_subset(
        index=example_data.index.query("participant not in ['114', '121']")
    )

.. GENERATED FROM PYTHON SOURCE LINES 99-102

Running a simple cross-validation with 2 folds puts both remaining "group_1" participants (100 and 105) into the
test data of the first fold.
Note that we skip the optimization of the pipeline to keep the example simple and fast.

.. GENERATED FROM PYTHON SOURCE LINES 102-116

.. code-block:: default

    from sklearn.model_selection import KFold

    from tpcp.optimize import DummyOptimize
    from tpcp.validate import cross_validate

    cv = KFold(n_splits=2)

    pipe = MyPipeline()
    optimizable_pipe = DummyOptimize(pipe)

    results = cross_validate(
        optimizable_pipe, data_imbalanced, scoring=score, cv=cv
    )
    result_df = pd.DataFrame(results)
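The imbalance caused by the unshuffled 2-fold split can be checked without running the pipeline.
The following is a minimal sketch in plain Python, not the tpcp or sklearn implementation: it hard-codes the index
shown above, drops participants 114 and 121, and mimics how an unshuffled k-fold split cuts the remaining rows into
consecutive blocks. The helper name ``contiguous_folds`` and the variable ``fold_summaries`` are made up for this
illustration.

```python
from collections import Counter

# Index of the example dataset as shown above: (patient_group, participant).
index = [
    ("group_1", "100"), ("group_2", "102"), ("group_3", "104"),
    ("group_1", "105"), ("group_2", "106"), ("group_3", "108"),
    ("group_1", "114"), ("group_2", "116"), ("group_3", "119"),
    ("group_1", "121"), ("group_2", "123"), ("group_3", "200"),
]

# Same subset as `data_imbalanced`: drop participants 114 and 121.
imbalanced = [row for row in index if row[1] not in {"114", "121"}]


def contiguous_folds(n_samples, n_splits):
    """Yield (train, test) index lists for an unshuffled k-fold split.

    Hypothetical helper for illustration only: test folds are consecutive
    blocks, and the first `n_samples % n_splits` folds get one extra sample.
    """
    sizes = [n_samples // n_splits] * n_splits
    for i in range(n_samples % n_splits):
        sizes[i] += 1
    start = 0
    for size in sizes:
        test = list(range(start, start + size))
        train = [i for i in range(n_samples) if i < start or i >= start + size]
        yield train, test
        start += size


# Count how many participants of each patient group end up in train and test.
fold_summaries = []
for fold, (train, test) in enumerate(contiguous_folds(len(imbalanced), 2)):
    train_groups = dict(Counter(imbalanced[i][0] for i in train))
    test_groups = dict(Counter(imbalanced[i][0] for i in test))
    fold_summaries.append((train_groups, test_groups))
    print(f"fold {fold}: train={train_groups} test={test_groups}")
```

Under these assumptions, the training data of the first fold contains no "group_1" participant at all, and the test
data of the second fold contains none either, which is the kind of imbalance that stratification is meant to avoid.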