r"""
.. _custom_dataset_basics:

Custom Dataset - Basics
=======================

Datasets represent a set of recordings that should all be processed in the same way.
For example the data of multiple participants in a study, multiple days of recording, or multiple tests.
The goal of datasets is to provide a consistent interface to access the raw data, metadata, and potential reference
information in an object-oriented way.
It is up to you to define, what is considered a single "data-point" for your dataset.
Note, that datasets can be arbitrarily nested (e.g. multiple participants with multiple recordings).

Datasets work best in combination with `Pipelines` and are further compatible with concepts like `GridSearch` and
`cross_validation`.
"""

# %%
# Defining your own dataset
# -------------------------
# Fundamentally, you only need to create a subclass of :class:`~tpcp.Dataset` and define the
# `create_index` method.
# This method should return a dataframe describing all the data-points that should be available in the dataset.
#
# .. warning:: Make absolutely sure that the dataframe you return is deterministic and does not change between runs!
#              This can lead to some nasty bugs!
#              We try to catch them internally, but it is not always possible.
#              As a tip, avoid relying on random numbers and make sure that the order does not depend on things
#              like file system order when creating an index by scanning a directory.
#              Particularly nasty are cases that use unordered containers like `set`, which sometimes maintain
#              their order, but sometimes don't.
#              At the very least, we recommend to sort the final dataframe you return in `create_index`.
#
# In the following we will create an example dataset, without any real world data,
# but it can be used to demonstrate most functionality.
# At the end we will discuss, how gait specific data should be integrated.
#
# We will define an index that contains 5 participants, with 3 recordings each.
# Recording 3 has 2 trials, while the others have only one.
# Note, that we implement this as a static index here, but most of the time, you would create the index by e.g. scanning
# and listing the files in your data directory.
# It is important that you don't load the entire actual data (e.g. the IMU samples) into memory here, but just list
# the available data-points in the index.
# Then you can filter the dataset first and load the data once you know which data-points you want to access.
# We will discuss this later in the example.
from itertools import product
from typing import Optional, Union

import pandas as pd

# Every recording has a first trial; only `rec_3` additionally has a second one.
trials = [(recording, "trial_1") for recording in ("rec_1", "rec_2", "rec_3")]
trials += [("rec_3", "trial_2")]
# Combine each of the 5 participants with every (recording, trial) pair.
# `product` iterates participant-major, so all rows of one participant are adjacent.
index = pd.DataFrame(
    [
        (participant, recording, trial)
        for participant, (recording, trial) in product(
            [f"p{number}" for number in range(1, 6)], trials
        )
    ],
    columns=["participant", "recording", "trial"],
)
index

# %%
# Now we use this index as the index of our new dataset.
# To see the dataset in action, we need to create an instance of it.
# Its string representation will show us the most important information.
from tpcp import Dataset


class CustomDataset(Dataset):
    """Minimal example dataset that is fully described by the static, module-level `index`."""

    def create_index(self) -> pd.DataFrame:
        """Return the dataframe that lists all data-points of this dataset."""
        return index


dataset = CustomDataset()
dataset

# %%
# Subsets
# -------
# When working with a dataset, the first thing is usually to select the data you want to use.
# For this, you can primarily use the method `get_subset`.
# Here we want to select only recording 2 and 3 from participant 1 to 4.
# Note that the returned subset is an instance of your dataset class as well.
subset = dataset.get_subset(
    participant=["p1", "p2", "p3", "p4"], recording=["rec_2", "rec_3"]
)
subset

# %%
# The subset can then be filtered further.
# For more advanced filter approaches you can also filter the index directly and use a bool-map to index the dataset
example_bool_map = subset.index["participant"].isin(["p1", "p2"])
final_subset = subset.get_subset(bool_map=example_bool_map)
final_subset

# %%
# Iteration and Groups
# --------------------------------
# After selecting the part of the data you want to use, you usually want/need to iterate over the data to apply your
# processing steps.
#
# By default, you can simply iterate over all rows.
# Note, that each row itself is a dataset again, but just with a single entry.
for row in final_subset:
    print(row)
    print(f"This row contains {len(row)} data-point", end="\n\n")

# %%
# However, in many cases, we don't want to iterate over all rows, but rather iterate over groups of the dataset
# (e.g. all participants or all tests) individually.
# We can do that in 2 ways (depending on what is needed).
# For example, if we want to iterate over all recordings, we can do this:
for recording_subset in final_subset.iter_level("recording"):
    print(recording_subset, end="\n\n")

# %%
# You can see that we get two subsets, one for each recording label.
# But what, if we want to iterate over the participants and the recordings together?
# In this case, we need to group our dataset first.
# Note that the grouped_subset shows the new groupby columns as the index in the representation and the length of the
# dataset is reported to be the number of groups.
grouped_subset = final_subset.groupby(["participant", "recording"])
print(f"The dataset contains {len(grouped_subset)} groups.")
grouped_subset

# %%
# If we now iterate the dataset, it will iterate over the unique groups.
#
# Grouping also changes the meaning of a "single datapoint".
# Each group reports a shape of `(1,)` independent of the number of rows in each group.
for group in grouped_subset:
    print(f"This group has the shape {group.shape}")
    print(group, end="\n\n")

# %%
# At any point, you can view all unique groups/rows in the dataset using the `group_labels` attribute.
# The order shown here, is the same order used when iterating the dataset.
# When creating a new subset, the order might change!
grouped_subset.group_labels

# %%
# .. note:: The `group_labels` attribute consists of a list of `named tuples
#           <https://docs.python.org/3/library/collections.html#
#           namedtuple-factory-function-for-tuples-with-named-fields>`_.
#           The tuple elements are named after the groupby columns and are in the same order as the groupby columns.
#           They can be accessed by name or index:
#           For example, `grouped_subset.group_labels[0].participant` and `grouped_subset.group_labels[0][0]` are equivalent.
#
#           Also, `grouped_subset.group_labels[0]` and `grouped_subset[0].group_label` are equivalent.


# %%
# Note that for an "un-grouped" dataset, this corresponds to all rows.
final_subset.group_labels

# %%
# If you want to view the full set of labels of a dataset regardless of the grouping,
# you can use the `index_as_tuples` method.
grouped_subset.index_as_tuples()

# %%
# Note that `index_as_tuples()` and `group_labels` return the same for an un-grouped dataset.
final_subset.index_as_tuples()

# %%
# We can use the group labels (or a subset of them) to index our dataset.
# This can be in particular helpful, if you want to recreate specific train test splits provided by `cross_validate`.
final_subset.get_subset(group_labels=final_subset.group_labels[:3])

# %%
# If you want, you can also ungroup a dataset again.
# This can be useful for a nested iteration:
for outer, group in enumerate(grouped_subset):
    ungrouped = group.groupby(None)
    for inner, subgroup in enumerate(ungrouped):
        print(outer, inner)
        print(subgroup, end="\n\n")

# %%
# Splitting
# ---------
# If you are evaluating algorithms, it is often important to split your data into a train and a test set, or multiple
# distinct sets for a cross validation.
#
# The `Dataset` objects directly support the `sklearn` helper functions for this.
# For example, to split our subset into training and testing we can do the following:
from sklearn.model_selection import train_test_split

train, test = train_test_split(final_subset, train_size=0.5)
print("Train:\n", train, end="\n\n")
print("Test:\n", test)

# %%
# Such splitting always occurs on a data-point level and can therefore be influenced by grouping.
# If we want to split our datasets into training and testing, but only based on the participants, we can do this:
train, test = train_test_split(
    final_subset.groupby("participant"), train_size=0.5
)
print("Train:\n", train, end="\n\n")
print("Test:\n", test)

# %%
# In the same way you can use the dataset (grouped or not) with the cross validation helper functions
# (KFold is just an example, all should work):
from sklearn.model_selection import KFold

cv = KFold(n_splits=2)
grouped_subset = final_subset.groupby("participant")
for train, test in cv.split(grouped_subset):
    # We only print the train set here
    print(grouped_subset[train], end="\n\n")

# %%
# While this works well, it is not always what we want.
# Sometimes, we still want to consider each row a single datapoint, but want to prevent that data of e.g. a single
# participant and recording is partially put into train- and partially into the test-split.
# For this, we can use `GroupKFold` in combination with `dataset.create_string_group_labels`.
#
# `create_string_group_labels` generates a unique string identifier for each row/group:
group_labels = final_subset.create_string_group_labels(
    ["participant", "recording"]
)
group_labels

# %%
# They can then be used as the `group` parameter in `GroupKFold` (and similar methods).
# Now the data of a single participant/recording combination is never split between the train and the test set.
from sklearn.model_selection import GroupKFold

cv = GroupKFold(n_splits=2)
for train, test in cv.split(final_subset, groups=group_labels):
    # We only print the train set here
    print(final_subset[train], end="\n\n")

# %%
# Instead of doing this manually, we also provide a custom splitter that does this for you.
# It allows us to directly put the dataset into the `split` method of `cross_validate` and use higher level semantics
# to specify the grouping and stratification.
from tpcp.validate import DatasetSplitter

cv = DatasetSplitter(
    GroupKFold(n_splits=2), groupby=["participant", "recording"]
)

for train, test in cv.split(final_subset):
    # We only print the train set here
    print(final_subset[train], end="\n\n")


# %%
# Creating labels also works for datasets that are already grouped.
# But, the columns that should be contained in the label must be a subset of the groupby columns in this case.
#
# The number of group labels is 4 in this case, as there are only 4 groups after grouping the dataset.
group_labels = final_subset.groupby(
    ["participant", "recording"]
).create_string_group_labels("participant")
group_labels

# %%
# Adding Data
# -----------
# So far we only operated on the index of the dataset.
# But if we want to run algorithms, we need the actual data (i.e. IMU samples, clinical data, ...).
#
# Because the data and the structure of the data can vary widely from dataset to dataset, it is up to you to implement
# data access.
# It comes down to documentation to ensure that users access the data in the correct way.
#
# In general, we try to follow a couple of conventions to give datasets a consistent feel:
#
# - Data access should be provided via `@property` decorator on the dataset objects, loading the data on demand.
# - The names of these properties should follow a consistent naming scheme (e.g. `data` for the core sensor data)
#   and should return values using the established datatypes (e.g. `pd.DataFrame`).
# - The names of values that represent gold standard information (i.e. values you would only have in an evaluation
#   dataset and should never use for training) should have a trailing `_`, which marks them as results, similar to
#   how sklearn handles it.
#
# This should look something like this:


class CustomDataset(Dataset):
    """Skeleton of a dataset that exposes its values via properties that load on demand."""

    def create_index(self):
        """Provide the static example index defined at module level."""
        return index

    @property
    def sampling_rate_hz(self) -> float:
        """Fixed sampling rate value used by this example."""
        return 204.8

    @property
    def data(self) -> pd.DataFrame:
        """Core sensor data.

        A real dataset would place the logic to load the data from disc here.
        """
        raise NotImplementedError

    @property
    def reference_events_(self) -> pd.DataFrame:
        """Gold-standard events of this validation dataset.

        Note the trailing `_` in the name.
        A real dataset would place its custom loading logic here.
        """
        raise NotImplementedError


# %%
# For each of the data-values you need to decide on which "level" you provide data access.
# Meaning, do you want to (or can you) return data when there are still multiple participants/recordings in the
# dataset, or can you only return the data when there is only a single trial of a single participant left?
#
# Usually, we recommend to always return the data on the lowest logical level (e.g. if you recorded separate IMU
# sessions per trial, you should provide access only if there is just a single trial by a single participant left in
# the dataset).
# Otherwise, you should throw an error.
# This pattern can be simplified using the `is_single` or `assert_is_single` helper method.
# These helpers check based on the provided `groupby_cols` if there is really just a single group/row left with the
# given groupby settings.
#
# Let's say `data` can be accessed on either a `recording` or a `trial` level, and `segmented_stride_list_` can only
# be accessed on a `trial` level.
# Then we could do something like this:


class CustomDataset(Dataset):
    """Example dataset that provides its values on different levels of the index.

    `data` can be accessed on a trial level (a single row) or on a recording level (a single
    participant/recording combination).
    `segmented_stride_list_` can only be accessed on a trial level.
    """

    @property
    def data(self) -> str:
        """Return a dummy string representing the data of the selected trial or recording.

        Raises
        ------
        ValueError
            If the dataset still contains more than a single participant/recording combination.

        """
        # Note that we need to make our checks from the most restrictive to the least restrictive level:
        # If there is only a single trial, there is also only a single recording.
        # Checking the recording level first would hence also match single trials, and the trial-level branch
        # could never be reached.
        # None -> single row
        if self.is_single(None):
            # We read the label from the un-grouped view, so that the label does not depend on the current grouping
            # of the dataset and matches the three placeholders.
            return "This is the data for participant {}, rec {} and trial {}".format(
                *self.groupby(None).group_label
            )
        recording_level = ["participant", "recording"]
        if self.is_single(recording_level):
            # Group explicitly by the level we just checked, so that the label matches the two placeholders.
            return "This is the data for participant {} and rec {}".format(
                *self.groupby(recording_level).group_label
            )
        raise ValueError(
            "Data can only be accessed when there is only a single recording of a single participant in the subset"
        )

    @property
    def sampling_rate_hz(self) -> float:
        """Fixed sampling rate value used by this example."""
        return 204.8

    @property
    def segmented_stride_list_(self) -> str:
        """Return a dummy string representing the stride list of the selected trial.

        `assert_is_single` is expected to raise an error if more than a single trial is left in the dataset.
        """
        # We use assert here, as we don't have multiple options.
        # (We could also use `None` for the `groupby_cols` here)
        self.assert_is_single(
            ["participant", "recording", "trial"], "segmented_stride_list_"
        )
        # As in `data`, we use the un-grouped label, so that all three index columns are available.
        return "This is the segmented stride list for participant {}, rec {} and trial {}".format(
            *self.groupby(None).group_label
        )

    def create_index(self):
        """Provide the static example index defined at module level."""
        return index


# %%
# If we select a single trial (row), we can get data and the stride list:
test_dataset = CustomDataset()
single_trial = test_dataset[0]
print(single_trial.data)
print(single_trial.segmented_stride_list_)

# %%
# If we only select a recording, we get an error for the stride list:

# We select only recording 3 here, as it has 2 trials.
single_recording = test_dataset.get_subset(recording="rec_3").groupby(
    ["participant", "recording"]
)[0]
print(single_recording.data)
try:
    print(single_recording.segmented_stride_list_)
except ValueError as e:
    # We only catch the `ValueError` that is expected when the access level is violated.
    # Any unrelated error stays visible instead of being printed with a misleading "ValueError" prefix.
    print("ValueError: ", e)

# %%
# Custom parameter
# ----------------
# Often it is required to pass some parameters/configuration to the dataset.
# This could be for example the place where the data is stored or if a specific part of the dataset should be included,
# if some preprocessing should be applied to the data, ... .
#
# Such additional configuration can be provided via a custom `__init__` and is then available for all methods to be
# used.
# Note that you **must** assign the configuration values to attributes with the same name and **must not** forget to
# call `super().__init__`


class CustomDatasetWithConfig(Dataset):
    """Example dataset that takes additional configuration parameters.

    Parameters
    ----------
    data_folder
        Custom configuration value; `create_index` could use it e.g. to locate the data.
    custom_config_para
        A further custom configuration value (not used by this example).
    groupby_cols
        Passed on unchanged to `Dataset.__init__`.
    subset_index
        Passed on unchanged to `Dataset.__init__`.

    """

    data_folder: str
    custom_config_para: bool

    def __init__(
        self,
        data_folder: str,
        custom_config_para: bool = False,
        *,
        groupby_cols: Optional[Union[list[str], str]] = None,
        subset_index: Optional[pd.DataFrame] = None,
    ):
        # Every configuration value is stored on an attribute with the same name as the parameter.
        self.data_folder = data_folder
        self.custom_config_para = custom_config_para
        # The generic dataset parameters are handled by the base class.
        super().__init__(groupby_cols=groupby_cols, subset_index=subset_index)

    def create_index(self):
        """Return the static example index defined at module level."""
        # Use e.g. `self.data_folder` to load the data.
        return index
