{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# The final QRS detection algorithms\n\nThese are the QRS detection algorithms, that we developed step by step `custom_algorithms_qrs_detection`.\nThis file can be used as quick reference or to import the class into other examples without side effects.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import List, Tuple, Union\n\nimport numpy as np\nimport pandas as pd\nfrom scipy import signal\nfrom scipy.spatial import KDTree, cKDTree, minkowski_distance\nfrom sklearn.metrics import roc_curve\n\nfrom tpcp import Algorithm, HyperParameter, OptimizableParameter, Parameter, make_action_safe, make_optimize_safe\n\n\ndef match_events_with_reference(events: np.ndarray, reference: np.ndarray, tolerance: Union[int, float]) -> np.ndarray:\n    \"\"\"Find matches in two lists based on the distance between their vectors.\n\n    Parameters\n    ----------\n    events : array with shape (n, d)\n        An n long array of d-dimensional vectors\n    reference : array with shape (m, d)\n        An m long array of d-dimensional vectors\n    tolerance\n        Max allowed Chebyshev distance between matches\n\n    Returns\n    -------\n    A array that marks all matches.\n    If one value is NaN, it means that no match was found for this index.\n\n    Notes\n    -----\n    Only a single match per index is allowed in both directions.\n    This means that every index will only occur once in the output arrays.\n    If multiple matches are possible based on the tolerance of the Chebyshev distance, the closest match will be\n    selected based on the Manhatten distance (aka `np.sum(np.abs(left_match - right_match`).\n    Only this match will be returned.\n    Note, that in the implementation, we first get the closest match based on the Manhatten distance and check in a\n    second step if this closed match is also valid based on the Chebyshev distance.\n\n    \"\"\"\n    if len(events) == 0 or len(reference) == 0:\n        return np.array([])\n\n    events = np.atleast_1d(events.squeeze())\n    reference = np.atleast_1d(reference.squeeze())\n    assert np.ndim(events) == 1, \"Events must be a 1D-array\"\n    assert np.ndim(reference) == 1, \"Reference must be a 1D-array\"\n    events = np.atleast_2d(events).T\n    reference = np.atleast_2d(reference).T\n\n    right_tree = KDTree(reference)\n    left_tree = KDTree(events)\n\n    # We calculate the closest neighbor based on the Manhatten distance in both directions and then find only the cases\n    # were the right side closest neighbor resulted in the same pairing as the left side closest neighbor ensuring\n    # that we have true one-to-one-matches\n    # p = 1 is used to select the Manhatten distance\n    l_nearest_distance, l_nearest_neighbor = right_tree.query(events, p=1, workers=-1)\n    _, r_nearest_neighbor = left_tree.query(reference, p=1, workers=-1)\n\n    # Filter the once that are true one-to-one matches\n    l_indices = np.arange(len(events))\n    combined_indices = np.vstack([l_indices, l_nearest_neighbor]).T\n    boolean_map = r_nearest_neighbor[l_nearest_neighbor] == l_indices\n    valid_matches = combined_indices[boolean_map]\n\n    # Check if the remaining matches are inside our Chebyshev tolerance distance.\n    # If not, delete them.\n    valid_matches_distance = l_nearest_distance[boolean_map]\n    index_large_matches = np.where(valid_matches_distance > tolerance)[0]\n    if index_large_matches.size > 0:\n        # Minkowski with p = np.inf uses the Chebyshev distance\n        output = (\n            minkowski_distance(events[index_large_matches], reference[valid_matches[index_large_matches, 1]], p=np.inf)\n            > tolerance\n        )\n\n        valid_matches = np.delete(valid_matches, index_large_matches[output], axis=0)\n\n    valid_matches = valid_matches\n    # Add invalid pairs to the output array\n    missing_l_indexes = np.setdiff1d(np.arange(len(events)), valid_matches[:, 0])\n    missing_l_matches = np.vstack([missing_l_indexes, np.full(len(missing_l_indexes), np.nan)]).T\n    missing_r_indexes = np.setdiff1d(np.arange(len(reference)), valid_matches[:, 1])\n    missing_r_matches = np.vstack([np.full(len(missing_r_indexes), np.nan), missing_r_indexes]).T\n    valid_matches = np.vstack([valid_matches, missing_l_matches, missing_r_matches])\n\n    return valid_matches\n\n\ndef precision_recall_f1_score(matches: np.ndarray):\n    if len(matches) == 0:\n        return 0, 0, 0\n    n_tp = np.sum((~np.isnan(matches)).all(axis=-1))\n    len_events = np.sum(~np.isnan(matches[:, 0]))\n    len_reference = np.sum(~np.isnan(matches[:, 1]))\n    precision = n_tp / len_events if len_events > 0 else 0\n    recall = n_tp / len_reference if len_reference > 0 else 0\n    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0\n    return precision, recall, f1\n\n\nclass QRSDetector(Algorithm):\n    _action_methods = \"detect\"\n\n    # Input Parameters\n    high_pass_filter_cutoff_hz: Parameter[float]\n    max_heart_rate_bpm: Parameter[float]\n    min_r_peak_height_over_baseline: Parameter[float]\n\n    # Results\n    r_peak_positions_: pd.Series\n\n    # Some internal constants\n    _HIGH_PASS_FILTER_ORDER: int = 4\n\n    def __init__(\n        self,\n        max_heart_rate_bpm: float = 200.0,\n        min_r_peak_height_over_baseline: float = 1.0,\n        high_pass_filter_cutoff_hz: float = 1,\n    ):\n        self.max_heart_rate_bpm = max_heart_rate_bpm\n        self.min_r_peak_height_over_baseline = min_r_peak_height_over_baseline\n        self.high_pass_filter_cutoff_hz = high_pass_filter_cutoff_hz\n\n    @make_action_safe\n    def detect(self, single_channel_ecg: pd.Series, sampling_rate_hz: float):\n        ecg = single_channel_ecg.to_numpy().flatten()\n\n        filtered_signal = self._filter(ecg, sampling_rate_hz)\n        peak_positions = self._search_strategy(filtered_signal, sampling_rate_hz)\n\n        self.r_peak_positions_ = pd.Series(peak_positions)\n        return self\n\n    def _search_strategy(\n        self, filtered_signal: np.ndarray, sampling_rate_hz: float, use_height: bool = True\n    ) -> np.ndarray:\n        # Calculate the minimal distance based on the expected heart rate\n        min_distance_between_peaks = 1 / (self.max_heart_rate_bpm / 60) * sampling_rate_hz\n\n        height = None\n        if use_height:\n            height = self.min_r_peak_height_over_baseline\n        peaks, _ = signal.find_peaks(filtered_signal, distance=min_distance_between_peaks, height=height)\n        return peaks\n\n    def _filter(self, ecg_signal: np.ndarray, sampling_rate_hz: float) -> np.ndarray:\n        sos = signal.butter(\n            btype=\"high\",\n            N=self._HIGH_PASS_FILTER_ORDER,\n            Wn=self.high_pass_filter_cutoff_hz,\n            output=\"sos\",\n            fs=sampling_rate_hz,\n        )\n        return signal.sosfiltfilt(sos, ecg_signal)\n\n\nclass OptimizableQrsDetector(QRSDetector):\n    min_r_peak_height_over_baseline: OptimizableParameter[float]\n    r_peak_match_tolerance_s: HyperParameter[float]\n\n    def __init__(\n        self,\n        max_heart_rate_bpm: float = 200.0,\n        min_r_peak_height_over_baseline: float = 1.0,\n        r_peak_match_tolerance_s: float = 0.01,\n        high_pass_filter_cutoff_hz: float = 1,\n    ):\n        self.r_peak_match_tolerance_s = r_peak_match_tolerance_s\n        super().__init__(\n            max_heart_rate_bpm=max_heart_rate_bpm,\n            min_r_peak_height_over_baseline=min_r_peak_height_over_baseline,\n            high_pass_filter_cutoff_hz=high_pass_filter_cutoff_hz,\n        )\n\n    @make_optimize_safe\n    def self_optimize(self, ecg_data: List[pd.Series], r_peaks: List[pd.Series], sampling_rate_hz: float):\n        all_labels = []\n        all_peak_heights = []\n        for d, p in zip(ecg_data, r_peaks):\n            filtered = self._filter(d.to_numpy().flatten(), sampling_rate_hz)\n            # Find all potential peaks without the height threshold\n            potential_peaks = self._search_strategy(filtered, sampling_rate_hz, use_height=False)\n            # Determine the label for each peak, by matching them with our ground truth\n            labels = np.zeros(potential_peaks.shape)\n            matches = match_events_with_reference(\n                events=np.atleast_2d(potential_peaks).T,\n                reference=np.atleast_2d(p.to_numpy().astype(int)).T,\n                tolerance=self.r_peak_match_tolerance_s * sampling_rate_hz,\n            )\n            tp_matches = matches[(~np.isnan(matches)).all(axis=1), 0].astype(int)\n            labels[tp_matches] = 1\n            labels = labels.astype(bool)\n            all_labels.append(labels)\n            all_peak_heights.append(filtered[potential_peaks])\n        all_labels = np.hstack(all_labels)\n        all_peak_heights = np.hstack(all_peak_heights)\n        # We \"brute-force\" a good cutoff by testing a bunch of thresholds and then calculating the Youden Index for\n        # each.\n        fpr, tpr, thresholds = roc_curve(all_labels, all_peak_heights)\n        youden_index = tpr - fpr\n        # The best Youden index gives us a balance between sensitivity and specificity.\n        self.min_r_peak_height_over_baseline = thresholds[np.argmax(youden_index)]\n        return self"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}