{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Dataclass and Attrs support\n\nWhen using `tpcp` you have to write a lot of classes with a lot of parameters.\nFor each class you need to repeat all parameter names up to 3 times, even before writing any documentation.\n\nBelow you can see the relevant part of the `QRSDetector` algorithm we implemented in another example.\nEven though it has only 3 parameters, it requires over 20 lines of code to define the basic initialization.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# `dataclass` and `field` are used by the dataclass examples further below\nfrom dataclasses import dataclass, field\n\nimport pandas as pd\n\nfrom tpcp import Algorithm, Parameter\n\n\nclass QRSDetector(Algorithm):\n    _action_methods = \"detect\"\n\n    # Input Parameters\n    high_pass_filter_cutoff_hz: Parameter[float]\n    max_heart_rate_bpm: Parameter[float]\n    min_r_peak_height_over_baseline: Parameter[float]\n\n    # Results\n    r_peak_positions_: pd.Series\n\n    # Some internal constants\n    _HIGH_PASS_FILTER_ORDER: int = 4\n\n    def __init__(\n        self,\n        max_heart_rate_bpm: float = 200.0,\n        min_r_peak_height_over_baseline: float = 1.0,\n        high_pass_filter_cutoff_hz: float = 0.5,\n    ):\n        self.max_heart_rate_bpm = max_heart_rate_bpm\n        self.min_r_peak_height_over_baseline = min_r_peak_height_over_baseline\n        self.high_pass_filter_cutoff_hz = high_pass_filter_cutoff_hz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Luckily, Python has a built-in solution for that, called `dataclasses`.\nWith that, we can write the class above much more compactly.\n\nThe only downside is that the annotation of result fields and constants is a little more verbose, and you **need** to\nmake sure that these attributes are excluded from the init.\nOtherwise, tpcp will explode ;)\n\nNote, if you are using Python >=3.10, we highly recommend using the `kw_only` option for dataclasses,\nwhich prevents some of the inheritance issues of dataclasses.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import ClassVar\n\n\n@dataclass(repr=False)  # We disable the automatic repr generation, as tpcp classes already have one. The default one might cause errors.\nclass QRSDetector(Algorithm):\n    _action_methods: ClassVar[str] = \"detect\"\n\n    # Input Parameters (with the same defaults as in the manual `__init__` version above)\n    high_pass_filter_cutoff_hz: Parameter[float] = 0.5\n    max_heart_rate_bpm: Parameter[float] = 200.0\n    min_r_peak_height_over_baseline: Parameter[float] = 1.0\n\n    # Results\n    # We need to add the special field annotation to exclude the result attribute from the init\n    r_peak_positions_: pd.Series = field(init=False, repr=False)\n\n    # Some internal constants\n    # Using the ClassVar annotation marks this value as a constant, and dataclasses will ignore it.\n    _HIGH_PASS_FILTER_ORDER: ClassVar[int] = 4"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We still get all parameters in the init:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "QRSDetector(high_pass_filter_cutoff_hz=4, max_heart_rate_bpm=200, min_r_peak_height_over_baseline=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inheritance\nCreating child classes of `dataclasses` is also simple.\nInstead of repeating all parameters, you just need to specify the new ones.\nHowever, you need to make sure that you also apply the `dataclass` decorator to the child class!\n\n<div class=\"alert alert-danger\"><h4>Warning</h4><p>New parameters will be added at the end of the positional order in the init method.\n             To avoid passing the wrong values to the wrong parameters, we highly recommend passing parameters\n             only by name and not by position, or using the `kw_only` parameter of dataclasses supported in Python\n             >=3.10.</p></div>\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@dataclass(repr=False)\nclass ModifiedQRSDetector(QRSDetector):\n    new_parameter: Parameter[float] = 3\n\n\nModifiedQRSDetector(\n    high_pass_filter_cutoff_hz=4, max_heart_rate_bpm=200, min_r_peak_height_over_baseline=1, new_parameter=3\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inheritance from complex tpcp classes\nWhile inheriting from other dataclasses works without issues, be aware that a dataclass cannot subclass a class that is\nnot a `dataclass` itself and defines its own `__init__` method!\nFor example, you cannot subclass `tpcp.optimize.GridSearch` with a dataclass, as it already defines its own\n`__init__`.\nIn this case you need to use a regular class and manually repeat all parent parameters\n(and call `super().__init__()`).\n\nWhile this might not be a big deal for the `GridSearch` class, as you are not expected to subclass it regularly, it\ncan become annoying for classes like `tpcp.Dataset` and `tpcp.optimize.optuna.CustomOptunaOptimize`,\nwhich already have an init and must be subclassed to work with them.\nFor these two classes (and other classes with predefined inits that we expect you to subclass), we provide an\n`as_dataclass` class method that returns a dataclass version of the respective class:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from itertools import product\n\nfrom tpcp import Dataset\n\n\n@dataclass(repr=False)\nclass CustomDataset(Dataset.as_dataclass()):  # Note the `as_dataclass` call here!\n    def create_index(self) -> pd.DataFrame:\n        return pd.DataFrame(\n            list(product((\"patient_1\", \"patient_2\", \"patient_3\"), (\"test_1\", \"test_2\"), (\"1\", \"2\"))),\n            columns=[\"patient\", \"test\", \"extra\"],\n        )\n\n    custom_param: float = 2  # This must have a default value, as the base class has parameters with defaults\n\n\nCustomDataset(custom_param=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Mutable Defaults\nIn `tpcp` we usually deal with the issue of mutable defaults by using the `tpcp.CloneFactory` (`tpcp.cf`).\nHowever, when using dataclasses, we can use the (more elegant) `field` annotation to define mutable defaults.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@dataclass(repr=False)\nclass FilterAlgorithm(Algorithm):\n    _action_methods: ClassVar[str] = \"filter\"\n\n    # Input Parameters\n    cutoff_hz: Parameter[float] = 2\n    order: Parameter[int] = 5\n\n    # Results\n    filtered_signal_: pd.Series = field(init=False, repr=False)\n\n\n@dataclass(repr=False)  # As above, we keep the repr provided by tpcp\nclass HigherLevelFilter(QRSDetector):\n    # The factory is called once per instance, so the default object is never shared between instances\n    filter_algorithm: Parameter[FilterAlgorithm] = field(default_factory=lambda: FilterAlgorithm(cutoff_hz=3, order=2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can see that each instance gets its own `FilterAlgorithm` object as default value, because the factory is called once per instance.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "v1 = HigherLevelFilter()\nv2 = HigherLevelFilter()\n\nnested_object_is_different = v1.filter_algorithm is not v2.filter_algorithm\nnested_object_is_different"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Attrs\nA popular alternative to dataclasses is [attrs](https://www.attrs.org).\nIt offers similar features to `dataclasses`, plus some additional ones that can be helpful.\nIt also supports `kw_only` for all Python versions (`kw_only` is great! Use it).\n\nYou can use it simply by replacing the `dataclass` decorator with the `attrs.define` decorator in most examples above.\nFurther, `attrs` has a `field` function that works like `dataclasses.field`.\nOnly the `default_factory` argument is called `factory`.\n\n<div class=\"alert alert-danger\"><h4>Warning</h4><p>`attrs` creates classes using `__slots__` instead of `__dict__` by default.\n             This does not work nicely with tpcp!\n             Use the `slots=False` parameter of `define`.</p></div>\n\nHere are the classes from above using attrs.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from attrs import Factory, define, field\n\n\n@define(kw_only=True, slots=False, repr=False)  # Slots don't play nicely with tpcp!\nclass QRSDetector(Algorithm):\n    _action_methods: ClassVar[str] = \"detect\"\n\n    # Input Parameters (with the same defaults as in the manual `__init__` version above)\n    high_pass_filter_cutoff_hz: Parameter[float] = 0.5\n    max_heart_rate_bpm: Parameter[float] = 200.0\n    min_r_peak_height_over_baseline: Parameter[float] = 1.0\n\n    # Results\n    r_peak_positions_: pd.Series = field(init=False)\n\n    # Some internal constants\n    _HIGH_PASS_FILTER_ORDER: ClassVar[int] = 4\n\n\n@define(kw_only=True, slots=False, repr=False)  # Slots don't play nicely with tpcp!\nclass FilterAlgorithm(Algorithm):\n    _action_methods: ClassVar[str] = \"filter\"\n\n    # Input Parameters\n    cutoff_hz: Parameter[float] = 2\n    order: Parameter[int] = 5\n\n    # Results\n    filtered_signal_: pd.Series = field(init=False)\n\n\n@define(kw_only=True, slots=False, repr=False)  # Slots don't play nicely with tpcp!\nclass HigherLevelFilter(QRSDetector):\n    # `Factory` creates a fresh default object for each instance\n    filter_algorithm: Parameter[FilterAlgorithm] = Factory(lambda: FilterAlgorithm(cutoff_hz=3, order=2))\n\n\nHigherLevelFilter()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To support subclassing tpcp classes with existing inits, we provide an `as_attrs` class method on the respective classes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@define(kw_only=True, slots=False, repr=False)  # Slots don't play nicely with tpcp!\nclass CustomDataset(Dataset.as_attrs()):  # Note the `as_attrs` call here!\n\n    custom_param: float  # We don't need a default, as we are using `kw_only` in define\n\n    def create_index(self) -> pd.DataFrame:\n        return pd.DataFrame(\n            list(product((\"patient_1\", \"patient_2\", \"patient_3\"), (\"test_1\", \"test_2\"), (\"1\", \"2\"))),\n            columns=[\"patient\", \"test\", \"extra\"],\n        )\n\n\nCustomDataset(custom_param=3)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}