.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/parameter_optimization/_05_optuna_search.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_parameter_optimization__05_optuna_search.py:


.. _build_in_optuna_optimizer:

Built-in Optuna Optimizers
==========================

The :ref:`custom optuna example ` shows how to implement a specific optuna optimizer with full control over all
aspects.
This is still the recommended approach, as you will often have specific requirements for your objective function.
However, a number of problems can be solved with a relatively generic GridSearch or GridSearchCV.
For these use cases, we provide Optuna equivalents that let you make use of the advanced samplers optuna provides.

.. note:: We still recommend reading through the :ref:`custom optuna example ` before using the specific
          implementations demonstrated here.

.. GENERATED FROM PYTHON SOURCE LINES 19-23

OptunaSearch - GridSearch on Steroids
+++++++++++++++++++++++++++++++++++++
The `OptunaSearch` class can be used in all cases where you would use :class:`~tpcp.optimize.GridSearch`.
The following is equivalent to the GridSearch example (:ref:`grid_search`).

.. GENERATED FROM PYTHON SOURCE LINES 23-59

.. code-block:: default


    from pathlib import Path

    import pandas as pd

    from tpcp import Parameter, Pipeline, cf
    from examples.algorithms.algorithms_qrs_detection_final import QRSDetector
    from examples.datasets.datasets_final_ecg import ECGExampleData

    try:
        HERE = Path(__file__).parent
    except NameError:
        HERE = Path().resolve()
    data_path = HERE.parent.parent / "example_data/ecg_mit_bih_arrhythmia/data"
    example_data = ECGExampleData(data_path)


    class MyPipeline(Pipeline[ECGExampleData]):
        algorithm: Parameter[QRSDetector]

        r_peak_positions_: pd.Series

        def __init__(self, algorithm: QRSDetector = cf(QRSDetector())):
            self.algorithm = algorithm

        def run(self, datapoint: ECGExampleData):
            # Note: We need to clone the algorithm instance, to make sure we don't leak any data between runs.
            algo = self.algorithm.clone()
            algo.detect(datapoint.data["ecg"], datapoint.sampling_rate_hz)
            self.r_peak_positions_ = algo.r_peak_positions_
            return self


    pipe = MyPipeline()

.. GENERATED FROM PYTHON SOURCE LINES 60-68

Optuna Study
------------
To use optuna, we need an optuna study.
Instead of creating the study ourselves, we provide a function that takes a seed and returns the parameters
`OptunaSearch` uses to create the study.
We set this up identically to the :ref:`custom optuna example `.

.. note:: We use an in-memory study here.
          If you want to use multiprocessing or be able to resume your search later, use a different study backend.

.. GENERATED FROM PYTHON SOURCE LINES 68-77

.. code-block:: default


    from optuna import Trial, samplers


    def get_study_params(seed):
        # We use a simple RandomSampler, but every optuna sampler will work
        sampler = samplers.RandomSampler(seed=seed)
        return {"direction": "maximize", "sampler": sampler}

.. GENERATED FROM PYTHON SOURCE LINES 78-83

Search Space
------------
In contrast to `GridSearch`, where we define a fixed parameter grid, in optuna we define a search space.
Which values in this search space are actually evaluated depends on the chosen sampler.
This also needs to be a function that takes the current trial object as input.

.. GENERATED FROM PYTHON SOURCE LINES 83-92

.. code-block:: default


    def create_search_space(trial: Trial):
        trial.suggest_float(
            "algorithm__min_r_peak_height_over_baseline", 0.1, 2, step=0.1
        )
        trial.suggest_float(
            "algorithm__high_pass_filter_cutoff_hz", 0.1, 2, step=0.1
        )

.. GENERATED FROM PYTHON SOURCE LINES 93-96

Score
-----
We use the same scoring function as in the `GridSearch` example:

.. GENERATED FROM PYTHON SOURCE LINES 96-116

.. code-block:: default


    from examples.algorithms.algorithms_qrs_detection_final import (
        match_events_with_reference,
        precision_recall_f1_score,
    )


    def score(pipeline: MyPipeline, datapoint: ECGExampleData):
        # We use the `safe_run` wrapper instead of just run. This is always a good idea.
        # We don't need to clone the pipeline here, as OptunaSearch will already clone the pipeline internally.
        pipeline = pipeline.safe_run(datapoint)
        tolerance_s = 0.02  # We just use 20 ms for this example
        matches = match_events_with_reference(
            pipeline.r_peak_positions_.to_numpy(),
            datapoint.r_peak_positions_.to_numpy(),
            tolerance=tolerance_s * datapoint.sampling_rate_hz,
        )
        precision, recall, f1_score = precision_recall_f1_score(matches)
        return {"precision": precision, "recall": recall, "f1_score": f1_score}

.. GENERATED FROM PYTHON SOURCE LINES 117-123

Running the search
------------------
Now we can run the search.
Note that, because our scoring function returns a dictionary, we need to specify the key we want to optimize by
passing it to `score_name`.
In this case, we want to maximize the f1 score.

.. GENERATED FROM PYTHON SOURCE LINES 123-136

.. code-block:: default


    from tpcp.optimize.optuna import OptunaSearch

    opti = OptunaSearch(
        pipe,
        get_study_params,
        create_search_space,
        scoring=score,
        n_trials=10,
        score_name="f1_score",
        random_seed=42,
    )
    opti = opti.optimize(example_data)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Datapoints: 0%| | 0/12 [00:00

.. csv-table::
   :header-rows: 1

   "", "datetime_start", "datetime_complete", "duration", "param__algorithm__high_pass_filter_cutoff_hz", "param__algorithm__min_r_peak_height_over_baseline", "state", "data_labels", "precision", "recall", "f1_score", "single__precision", "single__recall", "single__f1_score", "params"
   0, 2024-10-24 10:07:55.331170, 2024-10-24 10:07:56.104300, 0 days 00:00:00.773130, 2.0, 0.8, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.975941, 0.742356, 0.778327, "[1.0, 0.9739256397875422, 0.967756381549485, 0...", "[0.9995600527936648, 0.922267946959305, 0.9694...", "[0.9997799779977998, 0.9473931423203381, 0.968...", "{'algorithm__min_r_peak_height_over_baseline':..."
   1, 2024-10-24 10:07:56.105094, 2024-10-24 10:07:56.960224, 0 days 00:00:00.855130, 1.2, 1.5, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.827515, 0.308203, 0.341468, "[1.0, 1.0, 1.0, 0.9347826086956522, 0.99823943...", "[0.015398152221733392, 0.0004572473708276177, ...", "[0.030329289428076254, 0.0009140767824497258, ...", "{'algorithm__min_r_peak_height_over_baseline':..."
   2, 2024-10-24 10:07:56.961051, 2024-10-24 10:07:57.832541, 0 days 00:00:00.871490, 0.4, 0.4, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.874715, 0.869669, 0.858757, "[0.9995600527936648, 0.9711934156378601, 0.935...", "[0.9995600527936648, 0.9711934156378601, 0.967...", "[0.9995600527936648, 0.9711934156378601, 0.951...", "{'algorithm__min_r_peak_height_over_baseline':..."
   3, 2024-10-24 10:07:57.833306, 2024-10-24 10:07:58.707361, 0 days 00:00:00.874055, 1.8, 0.2, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.787612, 0.902061, 0.829179, "[0.9991204925241864, 0.9461024498886415, 0.907...", "[0.9995600527936648, 0.9711934156378601, 0.969...", "[0.9993402243237299, 0.9584837545126353, 0.937...", "{'algorithm__min_r_peak_height_over_baseline':..."
   4, 2024-10-24 10:07:58.708121, 2024-10-24 10:07:59.573288, 0 days 00:00:00.865167, 1.5, 1.3, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.913540, 0.428422, 0.472669, "[1.0, 1.0, 0.9914163090128756, 0.9884959522795...", "[0.3783545974483062, 0.002286236854138089, 0.2...", "[0.5489945738908394, 0.0045620437956204385, 0....", "{'algorithm__min_r_peak_height_over_baseline':..."
   5, 2024-10-24 10:07:59.574026, 2024-10-24 10:08:00.469313, 0 days 00:00:00.895287, 2.0, 0.1, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.628790, 0.916112, 0.735941, "[0.9109863672814755, 0.49212233549582945, 0.54...", "[0.9995600527936648, 0.9711934156378601, 0.969...", "[0.9532200545416404, 0.6532369675534369, 0.699...", "{'algorithm__min_r_peak_height_over_baseline':..."
   6, 2024-10-24 10:08:00.470068, 2024-10-24 10:08:01.333343, 0 days 00:00:00.863275, 0.5, 1.7, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.734090, 0.287615, 0.316558, "[1.0, 0, 1.0, 0.8153846153846154, 0.9962157048...", "[0.0004399472063352398, 0, 0.05787348586810229...", "[0.0008795074758135445, 0, 0.10941475826972011...", "{'algorithm__min_r_peak_height_over_baseline':..."
   7, 2024-10-24 10:08:01.334149, 2024-10-24 10:08:02.213621, 0 days 00:00:00.879472, 0.4, 0.4, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.874715, 0.869669, 0.858757, "[0.9995600527936648, 0.9711934156378601, 0.935...", "[0.9995600527936648, 0.9711934156378601, 0.967...", "[0.9995600527936648, 0.9711934156378601, 0.951...", "{'algorithm__min_r_peak_height_over_baseline':..."
   8, 2024-10-24 10:08:02.214397, 2024-10-24 10:08:03.087060, 0 days 00:00:00.872663, 1.1, 0.7, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.965966, 0.818186, 0.853251, "[0.9995600527936648, 0.9717514124293786, 0.965...", "[0.9995600527936648, 0.943758573388203, 0.9681...", "[0.9995600527936648, 0.9575504523312457, 0.966...", "{'algorithm__min_r_peak_height_over_baseline':..."
   9, 2024-10-24 10:08:03.087874, 2024-10-24 10:08:03.959345, 0 days 00:00:00.871471, 0.6, 0.9, COMPLETE, "[(group_1, 100), (group_2, 102), (group_3, 104...", 0.984398, 0.744233, 0.786534, "[0.9995600527936648, 0.9739130434782609, 0.966...", "[0.9995600527936648, 0.9218106995884774, 0.956...", "[0.9995600527936648, 0.9471458773784356, 0.961...", "{'algorithm__min_r_peak_height_over_baseline':..."
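
For the per-datapoint entries visible in the table (two of them are checked below), ``single__f1_score`` is the
harmonic mean of the corresponding ``single__precision`` and ``single__recall`` values.
The aggregated ``f1_score`` column, however, is not the F1 score of the aggregated ``precision`` and ``recall``
columns: for trial 0, that would be about 0.843 instead of the reported 0.778.
This is consistent with each aggregated column being the mean of the per-datapoint scores.
Note that this is our reading of the numbers, not something stated in this example; the per-datapoint lists are
truncated, so the means cannot be recomputed from the table.
The following plain-Python sketch checks both observations with values copied from the table:

.. code-block:: python

    # Values copied from the table above
    # (trial 1: first two datapoints; trial 0: aggregated columns).
    single_precision = [1.0, 1.0]
    single_recall = [0.015398152221733392, 0.0004572473708276177]
    single_f1_reported = [0.030329289428076254, 0.0009140767824497258]

    precision_0, recall_0, f1_0 = 0.975941, 0.742356, 0.778327


    def f1(precision: float, recall: float) -> float:
        """Harmonic mean of precision and recall (defined as 0 if both are 0)."""
        if precision + recall == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)


    # Per datapoint: recomputing F1 from precision and recall reproduces the reported values.
    single_f1_recomputed = [f1(p, r) for p, r in zip(single_precision, single_recall)]
    print("Recomputed:", single_f1_recomputed)
    print("Reported:  ", single_f1_reported)

    # Aggregated: the F1 of the aggregated precision and recall differs from the reported f1_score.
    f1_of_means = f1(precision_0, recall_0)
    print(f"F1 of aggregated precision/recall: {f1_of_means:.3f}, reported f1_score: {f1_0:.3f}")

Under this reading, ``best_score_`` is an average over datapoints, so a single datapoint with almost no correctly
detected peaks, such as the second datapoint of trial 1, noticeably lowers the score of a trial.
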

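
In the table, trials 2 and 7 were evaluated with the same parameter combination (0.4 for both parameters) and
produced identical scores.
This is expected with the search space defined above: with ``step=0.1``, each parameter can only take the 20 values
0.1, 0.2, ..., 2.0, so there are 20 x 20 = 400 combinations, and a sampler that draws each trial independently can
propose the same combination twice.
The sketch below enumerates these values and estimates how likely at least one repeat is within 10 trials.
It assumes that every combination is drawn uniformly and independently, which is our simplification of the
``RandomSampler`` behavior and not taken from its documentation:

.. code-block:: python

    from math import prod

    # Values implied by suggest_float(name, 0.1, 2, step=0.1); built from integers to avoid float drift.
    values = [round(i * 0.1, 1) for i in range(1, 21)]
    n_combinations = len(values) ** 2  # both parameters use the same range

    n_trials = 10
    # Birthday-problem estimate: probability that all n_trials draws are distinct, assuming
    # uniform, independent draws over all combinations (a simplification of the sampler).
    p_all_distinct = prod((n_combinations - i) / n_combinations for i in range(n_trials))
    p_repeat = 1 - p_all_distinct

    print(f"{len(values)} values per parameter, {n_combinations} combinations")
    print(f"P(at least one repeated combination in {n_trials} trials): {p_repeat:.3f}")

Under this assumption, the chance of at least one repeat is roughly 11 %, so seeing one in a 10-trial run is not
unusual.
For comparison, a ``GridSearch`` over the same grid would evaluate all 400 combinations.
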

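
Because ``score_name="f1_score"`` selects the objective and the study direction is ``"maximize"``, the best trial is
the one with the highest ``f1_score``.
The following plain-Python sketch applies this rule to the values from the table.
It only illustrates the selection logic and is not how tpcp or optuna implement it; in particular, we do not rely
on how ties between trials are resolved internally:

.. code-block:: python

    # (high_pass_filter_cutoff_hz, min_r_peak_height_over_baseline) and f1_score per trial,
    # copied from the table above (scores rounded to six decimals).
    trials = [
        ((2.0, 0.8), 0.778327),
        ((1.2, 1.5), 0.341468),
        ((0.4, 0.4), 0.858757),
        ((1.8, 0.2), 0.829179),
        ((1.5, 1.3), 0.472669),
        ((2.0, 0.1), 0.735941),
        ((0.5, 1.7), 0.316558),
        ((0.4, 0.4), 0.858757),
        ((1.1, 0.7), 0.853251),
        ((0.6, 0.9), 0.786534),
    ]

    best_score = max(score for _, score in trials)
    # Collect every trial that reaches the best score, since repeated combinations produce ties.
    best_trials = [i for i, (_, score) in enumerate(trials) if score == best_score]
    best_params = {trials[i][0] for i in best_trials}

    print("Best f1_score:", best_score)
    print("Reached by trials:", best_trials)
    print("Parameter combination(s):", best_params)

Both tied trials share the same parameters, so the result is unambiguous here: 0.4 for both parameters, which matches
``best_params_`` and ``best_score_`` of the search.
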
.. GENERATED FROM PYTHON SOURCE LINES 147-149

We can also get the best parameter combination and an instance of the pipeline initialized with it.

.. GENERATED FROM PYTHON SOURCE LINES 149-152

.. code-block:: default


    print("Best Para Combi:", opti.best_params_)
    print("Best score:", opti.best_score_)
    print("Paras of optimized Pipeline:", opti.optimized_pipeline_.get_params())


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Best Para Combi: {'algorithm__min_r_peak_height_over_baseline': 0.4, 'algorithm__high_pass_filter_cutoff_hz': 0.4}
    Best score: 0.858757056619628
    Paras of optimized Pipeline: {'algorithm__high_pass_filter_cutoff_hz': 0.4, 'algorithm__max_heart_rate_bpm': 200.0, 'algorithm__min_r_peak_height_over_baseline': 0.4, 'algorithm': QRSDetector(high_pass_filter_cutoff_hz=0.4, max_heart_rate_bpm=200.0, min_r_peak_height_over_baseline=0.4)}


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 9.313 seconds)

**Estimated memory usage:** 28 MB


.. _sphx_glr_download_auto_examples_parameter_optimization__05_optuna_search.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: _05_optuna_search.py <_05_optuna_search.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: _05_optuna_search.ipynb <_05_optuna_search.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_