Tensorflow/Keras#
Note
This example requires the tensorflow package to be installed.
Theoretically, tpcp is framework agnostic and can be used with any framework. However, due to the way some frameworks handle their objects, some special handling is required internally. Hence, this example serves not only as an example of how to use tensorflow with tpcp, but also as a test case for these special cases.
When using tpcp with any machine learning framework, you either want to use a pretrained model within a normal pipeline, or train your own model as part of an Optimizable Pipeline. Here we show the second case, as it is more complex. The first case is mostly straightforward; a minimal sketch of it is shown below.
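As a rough illustration of the pretrained case (this sketch is not part of the original example; the Pipeline base class is the regular tpcp one, while the model path and class name are placeholders):

import numpy as np
import tensorflow as tf

from tpcp import Pipeline


class PretrainedKerasPipeline(Pipeline):
    """Hypothetical pipeline wrapping an already trained Keras model."""

    predictions_: np.ndarray

    def __init__(self, model_path: str = "my_pretrained_model.keras"):
        self.model_path = model_path

    def run(self, datapoint) -> "PretrainedKerasPipeline":
        # Load the stored model (placeholder path) and predict on the datapoint.
        model = tf.keras.models.load_model(self.model_path)
        self.predictions_ = model.predict(datapoint.input_as_array())
        return self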
For the second case, this means we are planning to perform the following steps:
Create a pipeline that creates and trains a model.
Allow the modification of model hyperparameters.
Run a simple cross-validation to demonstrate the functionality.
This example reimplements the basic MNIST example from the [tensorflow documentation](https://www.tensorflow.org/tutorials/keras/classification).
Some Notes#
In this example we show how to implement a Pipeline that uses tensorflow. You could implement an Algorithm in a similar way. This would actually be easier, as no specific handling of the input data would be required. For a pipeline, we need to create a custom Dataset class, as this is the expected input for a pipeline.
The Dataset#
We are using the normal fashion MNIST dataset for this example. It consists of 60,000 images of 28x28 pixels, each with a label. We will ignore the typical train-test split, as we want to do our own cross-validation.
In addition, we will simulate an additional “index level”. In this dataset (and most typical deep learning datasets), each datapoint is one vector for which we can make one prediction. In tpcp, we usually deal with datasets where you might have multiple pieces of information for each datapoint. For example, one datapoint could be a patient, for whom we have an entire time series of measurements. We simulate this here by creating the index of our dataset as 1000 groups, each containing 60 images.
Other than that, the dataset is pretty standard.
Besides the create_index method, we only need to implement the input_as_array and labels_as_array methods, which allow us to easily access the data once we have selected a single group.
from functools import lru_cache
import numpy as np
import pandas as pd
import tensorflow as tf
from tpcp import Dataset
tf.keras.utils.set_random_seed(812)
tf.config.experimental.enable_op_determinism()
@lru_cache(maxsize=1)
def get_fashion_mnist_data():
    # Note: We throw train and test sets together, as we don't care about the official split here.
    # We will create our own split later.
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
    return np.array(list(train_images) + list(test_images)), list(train_labels) + list(test_labels)
class FashionMNIST(Dataset):
    def input_as_array(self) -> np.ndarray:
        self.assert_is_single(None, "input_as_array")
        group_id = int(self.group)
        images, _ = get_fashion_mnist_data()
        return images[group_id * 60 : (group_id + 1) * 60].reshape((60, 28, 28)) / 255

    def labels_as_array(self) -> np.ndarray:
        self.assert_is_single(None, "labels_as_array")
        group_id = int(self.group)
        _, labels = get_fashion_mnist_data()
        return np.array(labels[group_id * 60 : (group_id + 1) * 60])

    def create_index(self) -> pd.DataFrame:
        # There are 60,000 images in total.
        # We simulate 1000 groups of 60 images each.
        return pd.DataFrame({"group_id": list(range(1000))})
We can see our Dataset works as expected:
dataset = FashionMNIST()
dataset[0].input_as_array().shape
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
29515/29515 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26421880/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
5148/5148 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4422102/4422102 [==============================] - 0s 0us/step
(60, 28, 28)
dataset[0].labels_as_array().shape
(60,)
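Beyond accessing a single group, the standard tpcp Dataset semantics give us length and slicing for free. A quick sanity check (the expected values follow directly from the index we created):

len(dataset)  # -> 1000, one entry per group in the index
len(dataset[:10])  # -> 10, the kind of subset we will use for training below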
The Pipeline#
We will create a pipeline that uses a simple neural network to classify the images.
In tpcp, all “things” that should be optimized need to be parameters. This means our model itself needs to be a parameter of the pipeline. However, as we don’t have the model yet (its creation depends on other hyperparameters), we add it as an optional parameter initialized with None. Further, we prefix the parameter name with an underscore to signify that this is not a parameter that should be modified manually by the user. This is just a convention, and it is up to you to decide how you want to name your parameters.
We further introduce a hyperparameter n_dense_layer_nodes to show how we can influence the model creation.
The optimize method#
To make our pipeline optimizable, it needs to inherit from OptimizablePipeline. Further, we need to mark at least one of the parameters as OptiPara using the type annotation. We do this for our _model parameter.
Finally, we need to implement the self_optimize method. This method will get the entire training dataset as input and should update the _model parameter with the trained model. Hence, we first extract the relevant data (remember, each datapoint is 60 images) by concatenating all images over all groups in the dataset. Then we create the Keras model based on the hyperparameters. Finally, we train the model and update the _model parameter.
Here we chose to wrap the method with make_optimize_safe. This decorator will perform some runtime checks to ensure that the method is implemented correctly.
The run method#
The run method expects that the _model parameter is already set (i.e., the pipeline was already optimized). It gets a single datapoint as input (remember, a datapoint is a single group of 60 images).
We then extract the data from the datapoint and let the model make a prediction. We store the prediction in our output attribute predictions_. The trailing underscore is a convention to signify that this is a “result” attribute.
from tpcp import OptimizablePipeline, OptiPara, make_optimize_safe, make_action_safe
from typing import Optional, Tuple
from typing_extensions import Self
import warnings
class KerasPipeline(OptimizablePipeline):
    n_dense_layer_nodes: int
    n_train_epochs: int
    _model: OptiPara[Optional[tf.keras.Sequential]]

    predictions_: np.ndarray

    def __init__(self, n_dense_layer_nodes=128, n_train_epochs=5, _model: Optional[tf.keras.Sequential] = None):
        self.n_dense_layer_nodes = n_dense_layer_nodes
        self.n_train_epochs = n_train_epochs
        self._model = _model

    @property
    def predicted_labels_(self):
        return np.argmax(self.predictions_, axis=1)

    @make_optimize_safe
    def self_optimize(self, dataset, **_) -> Self:
        # Stack the images and labels of all groups into single training arrays.
        data = np.vstack([d.input_as_array() for d in dataset])
        labels = np.hstack([d.labels_as_array() for d in dataset])
        print(data.shape)
        if self._model is not None:
            warnings.warn("Overwriting existing model!")
        # Create the model based on the current hyperparameters.
        self._model = tf.keras.Sequential(
            [
                tf.keras.layers.Flatten(input_shape=(28, 28)),
                tf.keras.layers.Dense(self.n_dense_layer_nodes, activation="relu"),
                tf.keras.layers.Dense(10),
            ]
        )
        self._model.compile(
            optimizer="adam",
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=["accuracy"],
        )
        self._model.fit(data, labels, epochs=self.n_train_epochs)
        return self

    @make_action_safe
    def run(self, datapoint) -> Self:
        if self._model is None:
            raise RuntimeError("Model not trained yet!")
        data = datapoint.input_as_array()
        self.predictions_ = self._model.predict(data)
        return self
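tpcp pipelines expose an sklearn-style parameter interface. As a quick, hedged sanity check (the exact output formatting is an assumption, not taken from the original example), we can inspect a fresh instance:

KerasPipeline().get_params()
# Expected to contain _model=None, n_dense_layer_nodes=128, and n_train_epochs=5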
Testing the pipeline#
We can now test our pipeline. We will run the optimization using a couple of datapoints (to keep everything fast) and then use run to get the predictions for a single unseen datapoint.
pipeline = KerasPipeline().self_optimize(FashionMNIST()[:10])
p1 = pipeline.run(FashionMNIST()[11])
print(p1.predicted_labels_)
print(FashionMNIST()[11].labels_as_array())
(600, 28, 28)
Epoch 1/5
19/19 [==============================] - 1s 3ms/step - loss: 1.5112 - accuracy: 0.4933
Epoch 2/5
19/19 [==============================] - 0s 2ms/step - loss: 0.8705 - accuracy: 0.7083
Epoch 3/5
19/19 [==============================] - 0s 2ms/step - loss: 0.6799 - accuracy: 0.7850
Epoch 4/5
19/19 [==============================] - 0s 2ms/step - loss: 0.5962 - accuracy: 0.8133
Epoch 5/5
19/19 [==============================] - 0s 2ms/step - loss: 0.5158 - accuracy: 0.8400
2/2 [==============================] - 0s 2ms/step
[8 8 0 9 6 0 7 3 7 9 3 8 6 3 7 8 1 4 0 7 9 8 5 5 2 1 3 3 1 9 7 5 9 9 7 8 2
7 2 7 2 6 7 1 1 7 5 4 8 3 5 9 0 7 3 0 0 9 1 9]
[8 8 0 9 2 0 7 3 7 9 3 8 4 3 7 8 1 4 0 7 9 8 5 5 2 1 3 4 6 7 7 5 9 9 7 8 2
7 4 7 0 3 5 1 1 5 5 2 8 3 5 9 0 7 3 0 0 7 1 9]
We can see that even with just 5 epochs, the model already performs quite well. To quantify this, we can calculate the accuracy for this datapoint:
from sklearn.metrics import accuracy_score
accuracy_score(p1.predicted_labels_, FashionMNIST()[11].labels_as_array())
0.8
Cross Validation#
If we want to run a cross-validation, we need to formalize the scoring into a function. We will calculate two types of accuracy: first, the accuracy per group, and second, the accuracy over all images across all groups. For more information about how this works, check the custom_scorer example.
from typing import Sequence, Dict
from tpcp.validate import Aggregator
class SingleValueAccuracy(Aggregator[np.ndarray]):
    RETURN_RAW_SCORES = False

    @classmethod
    def aggregate(cls, /, values: Sequence[Tuple[np.ndarray, np.ndarray]], **_) -> Dict[str, float]:
        return {"accuracy": accuracy_score(np.hstack([v[0] for v in values]), np.hstack([v[1] for v in values]))}


def scoring(pipeline, datapoint):
    result: np.ndarray = pipeline.safe_run(datapoint).predicted_labels_
    reference = datapoint.labels_as_array()
    return {
        "accuracy": accuracy_score(result, reference),
        "per_sample": SingleValueAccuracy((result, reference)),
    }
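As a hedged usage sketch (not part of the original example), the scorer can be applied directly to the pipeline we trained above and a single datapoint:

single_scores = scoring(pipeline, FashionMNIST()[11])
single_scores["accuracy"]  # plain accuracy for this datapoint
single_scores["per_sample"]  # a SingleValueAccuracy wrapper that cross_validate will aggregate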
Now we can run the cross-validation. We will only run it on a subset of the data to keep the runtime manageable.
Note
You might see warnings about retracing of the model. This is because we clone the pipeline before each call to the run method. This is a good idea to ensure that all pipelines are independent of each other; however, it might result in some performance overhead.
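The cloning mentioned in the note uses the standard tpcp object interface. Conceptually (a hedged illustration, nothing you need to call yourself here):

# cross_validate works on independent copies of the pipeline, roughly equivalent to:
independent_pipeline = pipeline.clone()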
from tpcp.validate import cross_validate
from tpcp.optimize import Optimize
pipeline = KerasPipeline(n_train_epochs=10)
cv_results = cross_validate(Optimize(pipeline), FashionMNIST()[:100], scoring=scoring, cv=3)
CV Folds:   0%|          | 0/3 [00:00<?, ?it/s](3960, 28, 28)
Epoch 1/10
124/124 [==============================] - 1s 2ms/step - loss: 0.8968 - accuracy: 0.6919
Epoch 2/10
124/124 [==============================] - 0s 2ms/step - loss: 0.5606 - accuracy: 0.8104
Epoch 3/10
124/124 [==============================] - 0s 2ms/step - loss: 0.5009 - accuracy: 0.8232
Epoch 4/10
124/124 [==============================] - 0s 2ms/step - loss: 0.4448 - accuracy: 0.8437
Epoch 5/10
124/124 [==============================] - 0s 2ms/step - loss: 0.4119 - accuracy: 0.8540
Epoch 6/10
124/124 [==============================] - 0s 2ms/step - loss: 0.3792 - accuracy: 0.8687
Epoch 7/10
124/124 [==============================] - 0s 2ms/step - loss: 0.3555 - accuracy: 0.8768
Epoch 8/10
124/124 [==============================] - 0s 2ms/step - loss: 0.3464 - accuracy: 0.8826
Epoch 9/10
124/124 [==============================] - 0s 2ms/step - loss: 0.3196 - accuracy: 0.8886
Epoch 10/10
124/124 [==============================] - 0s 2ms/step - loss: 0.2921 - accuracy: 0.8997
CV Folds:  33%|###3      | 1/3 [00:09<00:19,  9.93s/it](4020, 28, 28)
Epoch 1/10
126/126 [==============================] - 1s 2ms/step - loss: 0.8913 - accuracy: 0.7000
Epoch 2/10
126/126 [==============================] - 0s 2ms/step - loss: 0.5662 - accuracy: 0.8002
Epoch 3/10
126/126 [==============================] - 0s 2ms/step - loss: 0.4893 - accuracy: 0.8279
Epoch 4/10
126/126 [==============================] - 0s 2ms/step - loss: 0.4431 - accuracy: 0.8453
Epoch 5/10
126/126 [==============================] - 0s 2ms/step - loss: 0.4046 - accuracy: 0.8580
Epoch 6/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3772 - accuracy: 0.8667
Epoch 7/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3479 - accuracy: 0.8786
Epoch 8/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3285 - accuracy: 0.8863
Epoch 9/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3120 - accuracy: 0.8900
Epoch 10/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3065 - accuracy: 0.8871
CV Folds:  67%|######6   | 2/3 [00:21<00:11, 11.11s/it](4020, 28, 28)
Epoch 1/10
126/126 [==============================] - 1s 2ms/step - loss: 0.8965 - accuracy: 0.6983
Epoch 2/10
126/126 [==============================] - 0s 2ms/step - loss: 0.5750 - accuracy: 0.8040
Epoch 3/10
126/126 [==============================] - 0s 2ms/step - loss: 0.5140 - accuracy: 0.8224
Epoch 4/10
126/126 [==============================] - 0s 2ms/step - loss: 0.4667 - accuracy: 0.8358
Epoch 5/10
126/126 [==============================] - 0s 2ms/step - loss: 0.4173 - accuracy: 0.8600
Epoch 6/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3949 - accuracy: 0.8624
Epoch 7/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3635 - accuracy: 0.8774
Epoch 8/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3457 - accuracy: 0.8826
Epoch 9/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3285 - accuracy: 0.8878
Epoch 10/10
126/126 [==============================] - 0s 2ms/step - loss: 0.3008 - accuracy: 0.8993
CV Folds: 100%|##########| 3/3 [00:34<00:00, 11.40s/it]
We can now look at the results per group:
cv_results["test_single_accuracy"]
[[0.85, 0.8666666666666667, 0.9333333333333333, 0.8, 0.85, 0.85, 0.85, 0.8833333333333333, 0.85, 0.85, 0.8166666666666667, 0.8666666666666667, 0.9333333333333333, 0.8, 0.8833333333333333, 0.8, 0.8, 0.8833333333333333, 0.7833333333333333, 0.8166666666666667, 0.8, 0.85, 0.8, 0.95, 0.8, 0.8833333333333333, 0.8833333333333333, 0.85, 0.85, 0.7833333333333333, 0.75, 0.8666666666666667, 0.8333333333333334, 0.9333333333333333], [0.7833333333333333, 0.8166666666666667, 0.7666666666666667, 0.7833333333333333, 0.85, 0.8, 0.8, 0.8166666666666667, 0.8, 0.8333333333333334, 0.8333333333333334, 0.8, 0.8166666666666667, 0.85, 0.75, 0.7, 0.8333333333333334, 0.8166666666666667, 0.8666666666666667, 0.7666666666666667, 0.7833333333333333, 0.7833333333333333, 0.7666666666666667, 0.8833333333333333, 0.8333333333333334, 0.7833333333333333, 0.8, 0.85, 0.85, 0.8166666666666667, 0.7666666666666667, 0.8, 0.8833333333333333], [0.7666666666666667, 0.9166666666666666, 0.8, 0.8, 0.7833333333333333, 0.85, 0.8833333333333333, 0.85, 0.9, 0.8833333333333333, 0.9, 0.85, 0.8833333333333333, 0.8166666666666667, 0.8333333333333334, 0.8333333333333334, 0.85, 0.7833333333333333, 0.7166666666666667, 0.7333333333333333, 0.75, 0.7833333333333333, 0.85, 0.7666666666666667, 0.7833333333333333, 0.9, 0.8666666666666667, 0.8333333333333334, 0.8333333333333334, 0.8166666666666667, 0.85, 0.85, 0.9]]
And the overall accuracy as the average over all samples of all groups within a fold:
cv_results["test_per_sample__accuracy"]
array([0.84705882, 0.80858586, 0.83080808])
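In general, the per-sample accuracy is not simply the mean of the group-wise values. Here, however, every group contains exactly 60 images, so the unweighted mean of the group accuracies per fold coincides with the per-sample accuracy. A small sketch to verify this:

import numpy as np

fold_means = [np.mean(fold) for fold in cv_results["test_single_accuracy"]]
fold_means  # matches cv_results["test_per_sample__accuracy"] up to floating point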
Total running time of the script: (0 minutes 39.100 seconds)
Estimated memory usage: 261 MB