{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# The final QRS detection algorithms\n\nThese are the QRS detection algorithms that we developed step by step in `custom_algorithms_qrs_detection`.\nThis file can be used as a quick reference or to import the classes into other examples without side effects.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import List, Tuple, Union\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from scipy import signal\n",
        "from scipy.spatial import cKDTree, minkowski_distance\n",
        "from sklearn.metrics import roc_curve\n",
        "\n",
        "from tpcp import Algorithm, HyperParameter, OptimizableParameter, Parameter, make_action_safe, make_optimize_safe\n",
        "\n",
        "\n",
        "def match_events_with_reference(\n",
        "    events: np.ndarray, reference: np.ndarray, tolerance: Union[int, float], one_to_one: bool = True\n",
        ") -> Tuple[np.ndarray, np.ndarray]:\n",
        "    \"\"\"Find matches between two lists of events based on the distance between them.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    events : array with shape (n,) or (n, 1)\n",
        "        An n long array of event positions.\n",
        "        The array is squeezed and must be 1D afterwards.\n",
        "    reference : array with shape (m,) or (m, 1)\n",
        "        An m long array of reference positions.\n",
        "        The array is squeezed and must be 1D afterwards.\n",
        "    tolerance\n",
        "        Max allowed Chebyshev distance between matches (inclusive)\n",
        "    one_to_one\n",
        "        If True only valid one-to-one matches are returned (see more below)\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    event_indices\n",
        "        Indices from the events array that have a match in the reference array.\n",
        "        If `one_to_one` is False, indices might repeat.\n",
        "    reference_indices\n",
        "        Indices from the reference array that have a match in the events array.\n",
        "        If `one_to_one` is False, indices might repeat.\n",
        "        A valid match pair is then `(event_indices[i], reference_indices[i])` for all i.\n",
        "\n",
        "    Both outputs are integer arrays (also when they are empty), so they can be used directly for indexing.\n",
        "\n",
        "    Notes\n",
        "    -----\n",
        "    This function supports 2 modes:\n",
        "\n",
        "    `one_to_one` = False:\n",
        "        In this mode every match is returned as long as the distance in all dimensions between the matches is at most\n",
        "        tolerance.\n",
        "        This is equivalent to the Chebyshev distance between the matches\n",
        "        (aka `np.max(np.abs(left_match - right_match)) <= tolerance`).\n",
        "        This means multiple matches for each vector will be returned.\n",
        "        This means the respective indices will occur multiple times in the output vectors.\n",
        "    `one_to_one` = True:\n",
        "        In this mode only a single match per index is allowed in both directions.\n",
        "        This means that every index will only occur once in the output arrays.\n",
        "        If multiple matches are possible based on the tolerance of the Chebyshev distance, the closest match will be\n",
        "        selected based on the Manhattan distance (aka `np.sum(np.abs(left_match - right_match))`).\n",
        "        Only this match will be returned.\n",
        "        Note, that in the implementation, we first get the closest match based on the Manhattan distance and check in a\n",
        "        second step if this closest match is also valid based on the Chebyshev distance.\n",
        "\n",
        "    \"\"\"\n",
        "    if len(events) == 0 or len(reference) == 0:\n",
        "        # Use an integer dtype, so that the empty result is still a valid index array.\n",
        "        # (The numpy default for an empty array would be float, which can not be used for indexing.)\n",
        "        return np.array([], dtype=int), np.array([], dtype=int)\n",
        "\n",
        "    events = np.atleast_1d(events.squeeze())\n",
        "    reference = np.atleast_1d(reference.squeeze())\n",
        "    assert np.ndim(events) == 1, \"Events must be a 1D-array\"\n",
        "    assert np.ndim(reference) == 1, \"Reference must be a 1D-array\"\n",
        "    events = np.atleast_2d(events).T\n",
        "    reference = np.atleast_2d(reference).T\n",
        "\n",
        "    right_tree = cKDTree(reference)\n",
        "    left_tree = cKDTree(events)\n",
        "\n",
        "    if not one_to_one:\n",
        "        # p = np.inf is used to select the Chebyshev distance\n",
        "        keys = list(zip(*right_tree.sparse_distance_matrix(left_tree, tolerance, p=np.inf).keys()))\n",
        "        if len(keys) == 0:\n",
        "            return np.array([], dtype=int), np.array([], dtype=int)\n",
        "        # All values are returned that have a valid match.\n",
        "        # The keys are (reference_index, event_index) pairs, as the matrix is built starting from the reference tree.\n",
        "        return np.array(keys[1], dtype=int), np.array(keys[0], dtype=int)\n",
        "\n",
        "    # one_to_one is True\n",
        "    # We calculate the closest neighbor based on the Manhattan distance in both directions and then find only the cases\n",
        "    # where the right side closest neighbor resulted in the same pairing as the left side closest neighbor ensuring\n",
        "    # that we have true one-to-one-matches\n",
        "\n",
        "    # p = 1 is used to select the Manhattan distance\n",
        "    l_nearest_distance, l_nearest_neighbor = right_tree.query(events, p=1, workers=-1)\n",
        "    _, r_nearest_neighbor = left_tree.query(reference, p=1, workers=-1)\n",
        "\n",
        "    # Filter the ones that are true one-to-one matches\n",
        "    l_indices = np.arange(len(events))\n",
        "    combined_indices = np.vstack([l_indices, l_nearest_neighbor]).T\n",
        "    boolean_map = r_nearest_neighbor[l_nearest_neighbor] == l_indices\n",
        "    valid_matches = combined_indices[boolean_map]\n",
        "\n",
        "    # Check if the remaining matches are inside our Chebyshev tolerance distance.\n",
        "    # If not, delete them.\n",
        "    valid_matches_distance = l_nearest_distance[boolean_map]\n",
        "    index_large_matches = np.where(valid_matches_distance > tolerance)[0]\n",
        "    if index_large_matches.size > 0:\n",
        "        # `index_large_matches` indexes the rows of `valid_matches` (not the events).\n",
        "        # Hence, both sides of each pair are looked up through `valid_matches`.\n",
        "        large_matches = valid_matches[index_large_matches]\n",
        "        # Minkowski with p = np.inf uses the Chebyshev distance\n",
        "        output = minkowski_distance(events[large_matches[:, 0]], reference[large_matches[:, 1]], p=np.inf) > tolerance\n",
        "\n",
        "        valid_matches = np.delete(valid_matches, index_large_matches[output], axis=0)\n",
        "\n",
        "    valid_matches = valid_matches.T\n",
        "\n",
        "    return valid_matches[0], valid_matches[1]\n",
        "\n",
        "\n",
        "class QRSDetector(Algorithm):\n",
        "    \"\"\"Detect R-peaks in a single-channel ECG signal using a high-pass filter followed by a peak search.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    max_heart_rate_bpm\n",
        "        The maximal heart rate expected in the data.\n",
        "        It is used to derive the minimal distance allowed between two detected peaks.\n",
        "    min_r_peak_height_over_baseline\n",
        "        The minimal height a peak in the high-pass filtered signal must have to be detected as R-peak.\n",
        "    high_pass_filter_cutoff_hz\n",
        "        The cutoff frequency of the Butterworth high-pass filter that is applied before the peak search.\n",
        "\n",
        "    Attributes\n",
        "    ----------\n",
        "    r_peak_positions_\n",
        "        The positions of the detected R-peaks as sample indices of the input signal.\n",
        "\n",
        "    \"\"\"\n",
        "\n",
        "    _action_methods = \"detect\"\n",
        "\n",
        "    # Input Parameters\n",
        "    high_pass_filter_cutoff_hz: Parameter[float]\n",
        "    max_heart_rate_bpm: Parameter[float]\n",
        "    min_r_peak_height_over_baseline: Parameter[float]\n",
        "\n",
        "    # Results\n",
        "    r_peak_positions_: pd.Series\n",
        "\n",
        "    # Some internal constants\n",
        "    _HIGH_PASS_FILTER_ORDER: int = 4\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        max_heart_rate_bpm: float = 200.0,\n",
        "        min_r_peak_height_over_baseline: float = 1.0,\n",
        "        high_pass_filter_cutoff_hz: float = 1,\n",
        "    ):\n",
        "        self.max_heart_rate_bpm = max_heart_rate_bpm\n",
        "        self.min_r_peak_height_over_baseline = min_r_peak_height_over_baseline\n",
        "        self.high_pass_filter_cutoff_hz = high_pass_filter_cutoff_hz\n",
        "\n",
        "    @make_action_safe\n",
        "    def detect(self, single_channel_ecg: pd.Series, sampling_rate_hz: float):\n",
        "        \"\"\"Detect the R-peaks in a single-channel ECG signal.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        single_channel_ecg\n",
        "            The ECG signal of a single channel.\n",
        "        sampling_rate_hz\n",
        "            The sampling rate of the ECG signal in Hz.\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        self\n",
        "            The instance itself with the result attribute `r_peak_positions_` set.\n",
        "\n",
        "        \"\"\"\n",
        "        ecg = single_channel_ecg.to_numpy().flatten()\n",
        "\n",
        "        filtered_signal = self._filter(ecg, sampling_rate_hz)\n",
        "        peak_positions = self._search_strategy(filtered_signal, sampling_rate_hz)\n",
        "\n",
        "        self.r_peak_positions_ = pd.Series(peak_positions)\n",
        "        return self\n",
        "\n",
        "    def _search_strategy(\n",
        "        self, filtered_signal: np.ndarray, sampling_rate_hz: float, use_height: bool = True\n",
        "    ) -> np.ndarray:\n",
        "        \"\"\"Find the peaks in the filtered signal.\n",
        "\n",
        "        If `use_height` is False, the height threshold is ignored and only the minimal peak distance is applied.\n",
        "        \"\"\"\n",
        "        # Calculate the minimal distance (in samples) based on the expected heart rate\n",
        "        min_distance_between_peaks = 1 / (self.max_heart_rate_bpm / 60) * sampling_rate_hz\n",
        "\n",
        "        height = None\n",
        "        if use_height:\n",
        "            height = self.min_r_peak_height_over_baseline\n",
        "        peaks, _ = signal.find_peaks(filtered_signal, distance=min_distance_between_peaks, height=height)\n",
        "        return peaks\n",
        "\n",
        "    def _filter(self, ecg_signal: np.ndarray, sampling_rate_hz: float) -> np.ndarray:\n",
        "        \"\"\"Apply the Butterworth high-pass filter forward and backward to the signal.\"\"\"\n",
        "        sos = signal.butter(\n",
        "            btype=\"high\",\n",
        "            N=self._HIGH_PASS_FILTER_ORDER,\n",
        "            Wn=self.high_pass_filter_cutoff_hz,\n",
        "            output=\"sos\",\n",
        "            fs=sampling_rate_hz,\n",
        "        )\n",
        "        return signal.sosfiltfilt(sos, ecg_signal)\n",
        "\n",
        "\n",
        "class OptimizableQrsDetector(QRSDetector):\n",
        "    \"\"\"QRS detector that can derive `min_r_peak_height_over_baseline` from data with reference R-peaks.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    max_heart_rate_bpm\n",
        "        The maximal heart rate expected in the data.\n",
        "        It is used to derive the minimal distance allowed between two detected peaks.\n",
        "    min_r_peak_height_over_baseline\n",
        "        The minimal height a peak in the high-pass filtered signal must have to be detected as R-peak.\n",
        "        This value is overwritten by `self_optimize`.\n",
        "    r_peak_match_tolerance_s\n",
        "        The maximal distance in seconds between a peak candidate and a reference R-peak to count as a match during\n",
        "        the optimization.\n",
        "    high_pass_filter_cutoff_hz\n",
        "        The cutoff frequency of the Butterworth high-pass filter that is applied before the peak search.\n",
        "\n",
        "    \"\"\"\n",
        "\n",
        "    min_r_peak_height_over_baseline: OptimizableParameter[float]\n",
        "    r_peak_match_tolerance_s: HyperParameter[float]\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        max_heart_rate_bpm: float = 200.0,\n",
        "        min_r_peak_height_over_baseline: float = 1.0,\n",
        "        r_peak_match_tolerance_s: float = 0.01,\n",
        "        high_pass_filter_cutoff_hz: float = 1,\n",
        "    ):\n",
        "        self.r_peak_match_tolerance_s = r_peak_match_tolerance_s\n",
        "        super().__init__(\n",
        "            max_heart_rate_bpm=max_heart_rate_bpm,\n",
        "            min_r_peak_height_over_baseline=min_r_peak_height_over_baseline,\n",
        "            high_pass_filter_cutoff_hz=high_pass_filter_cutoff_hz,\n",
        "        )\n",
        "\n",
        "    @make_optimize_safe\n",
        "    def self_optimize(self, ecg_data: List[pd.Series], r_peaks: List[pd.Series], sampling_rate_hz: float):\n",
        "        \"\"\"Find the peak height threshold that separates matched from unmatched peak candidates best.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        ecg_data\n",
        "            A list of single-channel ECG signals.\n",
        "        r_peaks\n",
        "            A list with the reference R-peak positions (in samples) for each of the signals in `ecg_data`.\n",
        "        sampling_rate_hz\n",
        "            The sampling rate of all ECG signals in Hz.\n",
        "\n",
        "        Returns\n",
        "        -------\n",
        "        self\n",
        "            The instance itself with an updated `min_r_peak_height_over_baseline`.\n",
        "\n",
        "        \"\"\"\n",
        "        all_labels = []\n",
        "        all_peak_heights = []\n",
        "        for ecg, reference_peaks in zip(ecg_data, r_peaks):\n",
        "            filtered = self._filter(ecg.to_numpy().flatten(), sampling_rate_hz)\n",
        "            # Find all potential peaks without the height threshold\n",
        "            potential_peaks = self._search_strategy(filtered, sampling_rate_hz, use_height=False)\n",
        "            # Determine the label for each peak, by matching them with our ground truth\n",
        "            labels = np.zeros(potential_peaks.shape)\n",
        "            matches, _ = match_events_with_reference(\n",
        "                events=np.atleast_2d(potential_peaks).T,\n",
        "                reference=np.atleast_2d(reference_peaks.to_numpy().astype(int)).T,\n",
        "                tolerance=self.r_peak_match_tolerance_s * sampling_rate_hz,\n",
        "                one_to_one=True,\n",
        "            )\n",
        "            # `matches` is an integer index array, even if no match was found.\n",
        "            labels[matches] = 1\n",
        "            labels = labels.astype(bool)\n",
        "            all_labels.append(labels)\n",
        "            all_peak_heights.append(filtered[potential_peaks])\n",
        "        all_labels = np.hstack(all_labels)\n",
        "        all_peak_heights = np.hstack(all_peak_heights)\n",
        "        # We \"brute-force\" a good cutoff by testing a bunch of thresholds and then calculating the Youden Index for\n",
        "        # each.\n",
        "        fpr, tpr, thresholds = roc_curve(all_labels, all_peak_heights)\n",
        "        youden_index = tpr - fpr\n",
        "        # The best Youden index gives us a balance between sensitivity and specificity.\n",
        "        self.min_r_peak_height_over_baseline = thresholds[np.argmax(youden_index)]\n",
        "        return self"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}