{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Algorithms - A real world example: QRS-Detection\n\nIn this example we will implement a custom algorithm and discuss when you might want to use an algorithm class instead of\njust pipelines.\n\nSpecifically, we will implement a simple algorithm designed to identify individual QRS complexes in a continuous\nECG signal.\nIf you have no idea what this all means, don't worry about it.\nSimply put, we want to find peaks in a continuous signal that contains some artifacts.\n\n<div class=\"alert alert-danger\"><h4>Warning</h4><p>The algorithm we design is **not** a good algorithm! There are much better and properly evaluated\n             algorithms for this job! Don't use this algorithm for anything :)</p></div>\n\n## When should you use custom algorithms?\nAlgorithms are a completely optional feature of tpcp and in many cases not required.\nHowever, algorithm subclasses provide a structured way to implement new algorithms when you don't have any better\nstructure to follow.\nFurthermore, they allow setting nested parameters (e.g. when used as parameters of pipelines) and can benefit from\nother tooling in tpcp (e.g. cloning).\nFor more general information, have a look at the general documentation page `datasets_algorithms_pipelines`.\n\n## Implementing QRS-Detection\nIn general, our QRS detection will have two steps:\n\n1. High-pass filter the data to remove baseline drift. We will use a Butterworth filter for that.\n2. Apply a peak finding strategy to find the (hopefully dominant) R-peaks.\n   We will use :func:`~scipy.signal.find_peaks` with a couple of parameters for that.\n\nLike all algorithms, our algorithm needs to inherit from `tpcp.Algorithm` and implement an action method.\nIn our case, we will call the action method `detect`, as this name reflects what the algorithm does.\nThis `detect` method will first do the filtering and then the peak search, which we split into two methods to keep\nthings easier to understand.\n\nIf you just want the final implementation without all the explanation, check\n`custom_algorithms_qrs_detection_final`.\n\nThe implementation below is still quite a bit of code, but let's focus on the aspects that are important in general:\n\n1. We inherit from `Algorithm`.\n2. We define all parameters as arguments of `__init__` and store them on the instance without modification.\n3. We define the name of our action method using `_action_methods = \"detect\"`.\n4. After we do the computations, we set the results on the instance (as attributes with a trailing underscore, e.g.\n   `r_peak_positions_`).\n5. We return `self`.\n6. (Optionally) we apply the :func:`~tpcp.make_action_safe` decorator to our action method, which performs some runtime\n   checks to ensure our implementation follows the tpcp spec.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import List\n\nimport numpy as np\nimport pandas as pd\nfrom scipy import signal\n\nfrom tpcp import Algorithm, Parameter, make_action_safe\n\n\nclass QRSDetector(Algorithm):\n    _action_methods = \"detect\"\n\n    # Input Parameters\n    high_pass_filter_cutoff_hz: Parameter[float]\n    max_heart_rate_bpm: Parameter[float]\n    min_r_peak_height_over_baseline: Parameter[float]\n\n    # Results\n    r_peak_positions_: pd.Series\n\n    # Some internal constants\n    _HIGH_PASS_FILTER_ORDER: int = 4\n\n    def __init__(\n        self,\n        max_heart_rate_bpm: float = 200.0,\n        min_r_peak_height_over_baseline: float = 1.0,\n        high_pass_filter_cutoff_hz: float = 0.5,\n    ):\n        self.max_heart_rate_bpm = max_heart_rate_bpm\n        self.min_r_peak_height_over_baseline = min_r_peak_height_over_baseline\n        self.high_pass_filter_cutoff_hz = high_pass_filter_cutoff_hz\n\n    @make_action_safe\n    def detect(self, single_channel_ecg: pd.Series, sampling_rate_hz: float):\n        ecg = single_channel_ecg.to_numpy().flatten()\n\n        filtered_signal = self._filter(ecg, sampling_rate_hz)\n        peak_positions = self._search_strategy(filtered_signal, sampling_rate_hz)\n\n        self.r_peak_positions_ = pd.Series(peak_positions)\n        return self\n\n    def _search_strategy(\n        self, filtered_signal: np.ndarray, sampling_rate_hz: float, use_height: bool = True\n    ) -> np.ndarray:\n        # Calculate the minimal distance based on the expected heart rate\n        min_distance_between_peaks = 1 / (self.max_heart_rate_bpm / 60) * sampling_rate_hz\n\n        height = None\n        if use_height:\n            height = self.min_r_peak_height_over_baseline\n        peaks, _ = signal.find_peaks(filtered_signal, distance=min_distance_between_peaks, height=height)\n        return peaks\n\n    def _filter(self, ecg_signal: np.ndarray, sampling_rate_hz: float) -> np.ndarray:\n        sos = signal.butter(\n            
btype=\"high\",\n            N=self._HIGH_PASS_FILTER_ORDER,\n            Wn=self.high_pass_filter_cutoff_hz,\n            output=\"sos\",\n            fs=sampling_rate_hz,\n        )\n        return signal.sosfiltfilt(sos, ecg_signal)"
      ]
    },
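    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the two processing steps in isolation, the next cell is a minimal sketch on a *synthetic* signal rather than on real ECG data. The signal, its 360 Hz sampling rate and all variable names are assumptions for illustration.\nEvenly spaced Gaussian pulses stand in for R-peaks and a slow sinusoid stands in for baseline drift.\nThe sketch uses the default parameters of `QRSDetector`. It also shows how `max_heart_rate_bpm` is converted into the\nminimal peak distance in samples that is passed to :func:`~scipy.signal.find_peaks`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom scipy import signal\n\n# Synthetic stand-in for an ECG (an assumption for illustration, not real data):\n# one narrow Gaussian \"R-peak\" per second on top of a slow 0.1 Hz baseline drift.\nsampling_rate_hz = 360.0  # assumed sampling rate\nt = np.arange(int(10 * sampling_rate_hz)) / sampling_rate_hz  # 10 s of sample times\nbeat_times_s = np.arange(0.5, 10, 1.0)\nspikes = 2.0 * sum(np.exp(-0.5 * ((t - bt) / 0.01) ** 2) for bt in beat_times_s)\ndrift = 1.5 * np.sin(2 * np.pi * 0.1 * t)\necg = spikes + drift\n\n# Step 1: 4th-order Butterworth high-pass with the `QRSDetector` default cutoff (0.5 Hz),\n# applied forward and backward (zero phase) so the peaks are not shifted in time.\nsos = signal.butter(N=4, Wn=0.5, btype=\"high\", output=\"sos\", fs=sampling_rate_hz)\nfiltered = signal.sosfiltfilt(sos, ecg)\n\n# Step 2: derive the minimal peak distance from the maximal heart rate:\n# 200 bpm -> 200 / 60 beats per second -> at least 0.3 s between beats -> 108 samples at 360 Hz.\nmax_heart_rate_bpm = 200.0\nmin_distance = 1 / (max_heart_rate_bpm / 60) * sampling_rate_hz\npeaks, _ = signal.find_peaks(filtered, distance=min_distance, height=1.0)\n\nprint(round(min_distance, 6), len(peaks))"
      ]
    },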
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Testing the implementation\nTo test the implementation, we load our example ECG data using the dataset created in a previous example.\n\nThe plot below shows that our algorithm works (at least for this piece of data).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n\nfrom examples.datasets.datasets_final_ecg import ECGExampleData\n\n# Loading the data\ntry:\n    HERE = Path(__file__).parent\nexcept NameError:\n    HERE = Path(\".\").resolve()\ndata_path = HERE.parent.parent / \"example_data/ecg_mit_bih_arrhythmia/data\"\nexample_data = ECGExampleData(data_path)\necg_data = example_data[0].data[\"ecg\"]\n\n# Initialize the algorithm\nalgorithm = QRSDetector()\nalgorithm = algorithm.detect(ecg_data, example_data.sampling_rate_hz)\n\n# Visualize the results\nimport matplotlib.pyplot as plt\n\nplt.figure()\nplt.plot(ecg_data[:5000])\nsubset_peaks = algorithm.r_peak_positions_[algorithm.r_peak_positions_ < 5000.0]\nplt.plot(subset_peaks, ecg_data[subset_peaks], \"s\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Making the algorithm trainable\nThe implementation so far heavily depends on the value of the `min_r_peak_height_over_baseline` parameter.\nIf this is set incorrectly, the detection will fail.\nThis parameter describes the minimal expected value of the filtered signal at the position of an R-peak.\nWithout looking at the filtered data, this value is hard to guess.\nIt depends on the preprocessing applied to the data and on the measurement conditions.\nHowever, we should be able to calculate a suitable value based on some training data (with R-peak annotations)\nrecorded under similar conditions.\n\nTherefore, we will create a second implementation of our algorithm that is *trainable*.\nThis means we will implement a method (`self_optimize`) that estimates a suitable value for our cutoff\nbased on some training data.\nNote that we do not provide a generic base class for optimizable algorithms.\nIf you need one, create your own class with a call signature for `self_optimize` that makes sense for the group of\nalgorithms you are trying to implement.\n\nFrom an implementation perspective, this means that we need to do the following things:\n\n1. Implement a `self_optimize` method that takes the data of multiple recordings including the reference labels to\n   calculate a suitable threshold. This method should modify only parameters marked as `OptimizableParameter` and then\n   return `self`.\n2. Mark the parameters that we want to optimize as `OptimizableParameter` using the type annotations on\n   the class level.\n3. Introduce a new parameter called `r_peak_match_tolerance_s` that is used by our `self_optimize` method.\n   Changing it changes the output of our optimization.\n   Therefore, it is a hyperparameter of our method, and we mark it as `HyperParameter` using the type hints on class\n   level.\n4. (Optional) Wrap the `self_optimize` method with the :func:`~tpcp.make_optimize_safe` decorator. It performs\n   some runtime checks and informs us if we did not implement `self_optimize` as expected.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>The process required to implement an optimizable algorithm will always be very similar to what we did\n          here.\n          It doesn't matter whether the optimization only tunes a threshold or trains a neural network.\n          The structure will be very similar.</p></div>\n\nFrom a scientific perspective, we optimize our parameter by first detecting all R-peak candidates without a height\nrestriction.\nBased on these candidates, we determine which of them are correct detections by checking whether they lie\nwithin `r_peak_match_tolerance_s` of a reference R-peak.\nThen we search for the height threshold that best separates correct from incorrect candidates.\nFor this, we compute a ROC curve over all candidate peak heights and pick the threshold with the highest Youden index\n(true positive rate minus false positive rate).\n\nAgain, there are probably better ways to do this, but this is just an example, and we already have a lot of code\nthat is not relevant for understanding the basics of algorithms.\n\n"
      ]
    },
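    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The matching step is the least obvious part of this procedure.\nThe next cell sketches the underlying idea with a simplified, hypothetical helper `label_detections`.\nIt is **not** the `match_events_with_reference` function used in the implementation below. Here, a detection simply counts\nas correct if any reference R-peak lies within the tolerance (converted from seconds to samples). Unlike a strict\none-to-one matching, several detections may therefore be matched to the same reference peak.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef label_detections(detected, reference, tolerance_samples):\n    \"\"\"Hypothetical helper (not part of tpcp): mark each detection as correct (True)\n    if at least one reference peak lies within `tolerance_samples` of it.\n\n    Unlike a strict one-to-one matching, two detections may match the same reference peak here.\n    \"\"\"\n    detected = np.asarray(detected)\n    reference = np.sort(np.asarray(reference))\n    if len(reference) == 0:\n        return np.zeros(len(detected), dtype=bool)\n    # Position of the first reference peak that is >= each detection\n    idx = np.searchsorted(reference, detected)\n    # Distance to the nearest reference peak on the right and on the left\n    right = np.abs(reference[np.clip(idx, 0, len(reference) - 1)] - detected)\n    left = np.abs(reference[np.clip(idx - 1, 0, len(reference) - 1)] - detected)\n    return np.minimum(left, right) <= tolerance_samples\n\n\nsampling_rate_hz = 360.0  # assumed sampling rate\n# r_peak_match_tolerance_s = 0.01 s corresponds to 3.6 samples at 360 Hz\ntolerance_samples = 0.01 * sampling_rate_hz\n# Made-up sample positions, for illustration only\ndetected = np.array([100, 250, 403, 600])\nreference = np.array([102, 400, 598])\nlabels = label_detections(detected, reference, tolerance_samples)\nprint(labels)  # -> [ True False  True  True]"
      ]
    },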
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn.metrics import roc_curve\n\nfrom examples.algorithms.algorithms_qrs_detection_final import match_events_with_reference\nfrom tpcp import HyperParameter, OptimizableParameter, make_optimize_safe\n\n\nclass OptimizableQrsDetector(QRSDetector):\n    min_r_peak_height_over_baseline: OptimizableParameter[float]\n    r_peak_match_tolerance_s: HyperParameter[float]\n\n    def __init__(\n        self,\n        max_heart_rate_bpm: float = 200.0,\n        min_r_peak_height_over_baseline: float = 1.0,\n        r_peak_match_tolerance_s: float = 0.01,\n        high_pass_filter_cutoff_hz: float = 1,\n    ):\n        self.r_peak_match_tolerance_s = r_peak_match_tolerance_s\n        super().__init__(\n            max_heart_rate_bpm=max_heart_rate_bpm,\n            min_r_peak_height_over_baseline=min_r_peak_height_over_baseline,\n            high_pass_filter_cutoff_hz=high_pass_filter_cutoff_hz,\n        )\n\n    @make_optimize_safe\n    def self_optimize(self, ecg_data: List[pd.Series], r_peaks: List[pd.Series], sampling_rate_hz: float):\n        all_labels = []\n        all_peak_heights = []\n        for d, p in zip(ecg_data, r_peaks):\n            filtered = self._filter(d.to_numpy().flatten(), sampling_rate_hz)\n            # Find all potential peaks without the height threshold\n            potential_peaks = self._search_strategy(filtered, sampling_rate_hz, use_height=False)\n            # Determine the label for each peak, by matching them with our ground truth\n            labels = np.zeros(potential_peaks.shape)\n            matches = match_events_with_reference(\n                events=potential_peaks,\n                reference=p.to_numpy().astype(int),\n                tolerance=self.r_peak_match_tolerance_s * sampling_rate_hz,\n            )\n            tp_matches = matches[(~np.isnan(matches)).all(axis=1), 0].astype(int)\n            labels[tp_matches] = 1\n            labels = labels.astype(bool)\n            all_labels.append(labels)\n          
  all_peak_heights.append(filtered[potential_peaks])\n        all_labels = np.hstack(all_labels)\n        all_peak_heights = np.hstack(all_peak_heights)\n        # We \"brute-force\" a good cutoff by testing a bunch of thresholds and then calculating the Youden Index for\n        # each.\n        fpr, tpr, thresholds = roc_curve(all_labels, all_peak_heights)\n        youden_index = tpr - fpr\n        # The best Youden index gives us a balance between sensitivity and specificity.\n        self.min_r_peak_height_over_baseline = thresholds[np.argmax(youden_index)]\n        return self"
      ]
    },
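    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate how the threshold is picked, the next cell applies the same ROC/Youden-index logic as the last lines of\n`self_optimize` to a handful of made-up peak heights and labels (toy values, not real detections).\n:func:`~sklearn.metrics.roc_curve` returns the true and false positive rates for every candidate threshold, and we keep\nthe threshold with the largest Youden index (`tpr - fpr`).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom sklearn.metrics import roc_curve\n\n# Toy data, made up for illustration: heights of preliminary peak candidates and\n# whether each candidate matched a reference R-peak (True) or not (False).\nlabels = np.array([True, True, True, True, False, False, False, False, False])\nheights = np.array([1.5, 1.3, 1.2, 0.9, 1.0, 0.5, 0.4, 0.2, 0.1])\n\nfpr, tpr, thresholds = roc_curve(labels, heights)\n# Youden index J = sensitivity + specificity - 1 = tpr - fpr\nyouden_index = tpr - fpr\nbest_threshold = thresholds[np.argmax(youden_index)]\n# At 0.9 all four true peaks are kept and only one of five false candidates passes\nprint(best_threshold)  # -> 0.9"
      ]
    },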
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Testing the trainable implementation\nTo test the trainable implementation, we need a train and a test set.\nIn this case, we simply use the first two recordings as train set and the recording at index 3, which is not part of\nthe train set, as test set.\n\nFirst, we call `self_optimize` with the train data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "train_data = example_data[:2]\ntrain_ecg_data = [d.data[\"ecg\"] for d in train_data]\ntrain_r_peaks = [d.r_peak_positions_[\"r_peak_position\"] for d in train_data]\n\nalgorithm = OptimizableQrsDetector()\nalgorithm = algorithm.self_optimize(train_ecg_data, train_r_peaks, train_data.sampling_rate_hz)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After the optimization, we can access the modified parameters.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\n    \"The optimized value of the threshold `min_r_peak_height_over_baseline` is:\",\n    algorithm.min_r_peak_height_over_baseline,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then we can apply the optimized algorithm to our test set.\nAgain, we can see that the algorithm works fine on the piece of data we are inspecting here.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "test_data = example_data[3]\ntest_ecg_data = test_data.data[\"ecg\"]\n\nalgorithm = algorithm.detect(test_ecg_data, test_data.sampling_rate_hz)\n\n# Visualize the results\nplt.figure()\nplt.plot(test_ecg_data[:5000])\nsubset_peaks = algorithm.r_peak_positions_[algorithm.r_peak_positions_ < 5000.0]\nplt.plot(subset_peaks, test_ecg_data[subset_peaks], \"s\")\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}