Purpose: The previous two notebooks discussed theoretical and foundational aspects of deep learning models, in particular what types of architectures and activation functions exist (and, to a lesser extent, how to choose among them). In this notebook, our goal is to actually build, assess, and use a deep learning network for image classification.
Since you set up TensorFlow in an earlier notebook, let’s load {reticulate} in an R code chunk and then import {tensorflow} and {keras} using a Python chunk. We’ll use the Fashion MNIST data set, so we’ll load that as well. You can learn more about that data set from its official repository here.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.datasets import fashion_mnist
from keras.models import Sequential
from keras.layers import Flatten, Dense, Dropout, Lambda

# Fashion MNIST ships already split into training and test sets
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Scale pixel intensities from integers in [0, 255] to floats in [0, 1]
x_train = x_train / 255
x_test = x_test / 255

# Lookup table from numeric label to item name
labels_df = pd.DataFrame({
    "label": range(10),
    "item": ["Tshirt", "Trousers", "Pullover", "Dress", "Coat",
             "Sandal", "Shirt", "Sneaker", "Bag", "AnkleBoot"]
})
In the code block above, we loaded the Fashion MNIST data, which comes already packaged into training and test sets. We then scaled the pixel intensities from integer values (between 0 and 255) to floats between 0 and 1. We created a data frame of labels for convenience, since the labels in y_train and y_test are numeric only. Finally, we wrote a function to rotate the matrix of pixel intensities so that the images will be arranged vertically when we plot them; this is important for us humans but of no importance to the neural network we’ll be training.
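To make the scaling and label lookup steps concrete, here is a minimal sketch that runs without downloading anything. The random images are a synthetic stand-in for Fashion MNIST, and `label_to_item` is a hypothetical helper that plays the role of the labels data frame.

```python
import numpy as np

# Synthetic stand-in for two 28x28 grayscale images (assumption: not real data).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(2, 28, 28), dtype=np.uint8)

# Dividing by 255 promotes the integers to floats in the range [0, 1].
scaled = fake_images / 255

# Class names in label order, mirroring the "item" column of labels_df.
item_names = ["Tshirt", "Trousers", "Pullover", "Dress", "Coat",
              "Sandal", "Shirt", "Sneaker", "Bag", "AnkleBoot"]

def label_to_item(label):
    """Return the item name for a numeric label (hypothetical helper)."""
    return item_names[int(label)]

print(scaled.dtype, scaled.min(), scaled.max())
print(label_to_item(9))
```

The same lookup could be done by filtering the data frame on its label column; a plain list is used here only to keep the sketch dependency-free.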
Let’s take a look at a few items and their labels.
## Item is: 0 Tshirt
## Name: item, dtype: object
Okay – I’m having a difficult time identifying some of these items. Can we train a sequential neural network to learn the classes?
model = Sequential([
    Flatten(input_shape = (28, 28)),
    Dense(128, activation = "relu"),
    Dropout(0.2),
    Dense(10, activation = "softmax")
])
model.summary()
## Model: "sequential"
## _________________________________________________________________
##  Layer (type)                Output Shape              Param #
## =================================================================
##  flatten (Flatten)           (None, 784)               0
##
##  dense (Dense)               (None, 128)               100480
##
##  dropout (Dropout)           (None, 128)               0
##
##  dense_1 (Dense)             (None, 10)                1290
##
## =================================================================
## Total params: 101770 (397.54 KB)
## Trainable params: 101770 (397.54 KB)
## Non-trainable params: 0 (0.00 Byte)
## _________________________________________________________________
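The parameter counts in the summary can be checked by hand: a dense layer has one weight per input-unit pair plus one bias per unit, while the flatten and dropout layers have no parameters at all. A small sketch, where `dense_param_count` is a hypothetical helper name:

```python
def dense_param_count(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

flatten_outputs = 28 * 28                                 # 784 pixels per image
hidden_params = dense_param_count(flatten_outputs, 128)   # first Dense layer
output_params = dense_param_count(128, 10)                # output Dense layer
total_params = hidden_params + output_params

print(hidden_params, output_params, total_params)  # 100480 1290 101770
```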
We have a model with over 100,000 parameters! Because random weights are initially set for each of these, we can use the model straight “out of the box” for prediction. We shouldn’t expect the network to perform very well though.
## 1/1 [==============================] - 0s 115ms/step
## array([[0.10245589, 0.08768913, 0.12370513, 0.18182682, 0.08118267,
## 0.0402923 , 0.23050708, 0.05350573, 0.04104901, 0.05778628],
## [0.0975403 , 0.11283207, 0.0658805 , 0.2278418 , 0.04386355,
## 0.05245278, 0.11948042, 0.16097698, 0.04762336, 0.07150826]],
## dtype=float32)
## <tf.Tensor: shape=(2, 10), dtype=float32, numpy=
## array([[0.10006306, 0.0985963 , 0.10221207, 0.10832884, 0.09795687,
## 0.09403217, 0.11373278, 0.0952829 , 0.09410335, 0.09569164],
## [0.09959807, 0.1011328 , 0.09649421, 0.1134593 , 0.09439291,
## 0.09520716, 0.10180741, 0.10612094, 0.09474847, 0.09703878]],
## dtype=float32)>
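To see why an untrained network still returns valid (if uninformative) class probabilities, here is a rough NumPy imitation of the forward pass. It is a sketch rather than what Keras does internally: the images are random stand-ins, the weights use Glorot-uniform initialization with zero biases (the Keras defaults for dense layers), and dropout is left out because it is inactive at prediction time.

```python
import numpy as np

rng = np.random.default_rng(42)

def glorot_uniform(fan_in, fan_out):
    """Random weights drawn uniformly from [-limit, limit]."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

# Randomly initialized layers with the same shapes as the model above.
w_hidden, b_hidden = glorot_uniform(784, 128), np.zeros(128)
w_out, b_out = glorot_uniform(128, 10), np.zeros(10)

def forward(images):
    """Flatten -> Dense(128, relu) -> Dense(10, softmax)."""
    flat = images.reshape(len(images), -1)
    hidden = np.maximum(flat @ w_hidden + b_hidden, 0.0)
    logits = hidden @ w_out + b_out
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exps = np.exp(logits)
    return exps / exps.sum(axis=1, keepdims=True)

fake_images = rng.random((2, 28, 28))   # synthetic stand-in for two scaled images
untrained_probs = forward(fake_images)

print(untrained_probs.round(3))
print(untrained_probs.sum(axis=1))      # each row sums to 1
```

Each row is a probability distribution over the 10 classes, but since the weights are random, the largest entry has no reason to match the true class.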
Let’s define a loss function so that we can train the model by optimizing the loss.
# One-hot encode the training labels: row i gets a 1 in the column of its class
y_train_wide = np.zeros((len(y_train), 10))
for i in range(len(y_train)):
    label = y_train[i]
    y_train_wide[i, label] = 1
predictions
## array([[0.10245589, 0.08768913, 0.12370513, 0.18182682, 0.08118267,
## 0.0402923 , 0.23050708, 0.05350573, 0.04104901, 0.05778628],
## [0.0975403 , 0.11283207, 0.0658805 , 0.2278418 , 0.04386355,
## 0.05245278, 0.11948042, 0.16097698, 0.04762336, 0.07150826]],
## dtype=float32)
## array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
## [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
## <tf.Tensor: shape=(), dtype=float32, numpy=2.5892467>
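The loss value printed above appears to be consistent with categorical cross-entropy: for each record, take the negative log of the probability assigned to the true class, then average over records. The sketch below redoes that arithmetic in NumPy using the two predictions and the two one-hot rows shown above. The variable names are ours, and that the hidden chunk used this particular loss is an assumption.

```python
import numpy as np

# The two untrained predictions printed above.
example_predictions = np.array([
    [0.10245589, 0.08768913, 0.12370513, 0.18182682, 0.08118267,
     0.0402923, 0.23050708, 0.05350573, 0.04104901, 0.05778628],
    [0.0975403, 0.11283207, 0.0658805, 0.2278418, 0.04386355,
     0.05245278, 0.11948042, 0.16097698, 0.04762336, 0.07150826],
])

# The matching one-hot rows: true classes 9 and 0.
example_targets = np.zeros((2, 10))
example_targets[0, 9] = 1
example_targets[1, 0] = 1

# Per-record loss is -log(probability of the true class); then average.
per_record_loss = -np.sum(example_targets * np.log(example_predictions), axis=1)
mean_loss = per_record_loss.mean()

print(per_record_loss)
print(mean_loss)  # close to the 2.589 reported above
```

A network guessing uniformly at random would score -log(0.1), about 2.30, so the untrained model is doing slightly worse than uniform guessing on these two records.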
Before training, we’ll need to set the optimizer, assign the loss function, and define the performance metric. We’ll then compile the model with these attributes.
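To illustrate how these three pieces fit together, here is a deliberately simplified NumPy sketch: a single softmax layer trained on synthetic data, with plain gradient descent as the optimizer, cross-entropy as the loss, and accuracy as the metric. This is not the notebook’s compile step. The data, learning rate, and function names are all assumptions made for illustration.

```python
import numpy as np

def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=1, keepdims=True)

def cross_entropy(y_wide, probs):
    """Loss: the quantity the optimizer tries to make small."""
    return float(-np.mean(np.sum(y_wide * np.log(probs + 1e-12), axis=1)))

def accuracy(labels, probs):
    """Metric: reported for humans, not used to update the weights."""
    return float(np.mean(probs.argmax(axis=1) == labels))

# Synthetic stand-in data: 64 flattened "images" with random labels.
rng = np.random.default_rng(0)
x = rng.random((64, 784))
labels = rng.integers(0, 10, size=64)
y_wide = np.eye(10)[labels]

weights = rng.normal(scale=0.01, size=(784, 10))
biases = np.zeros(10)
learning_rate = 0.002   # assumed value, kept small for stable updates

loss_before = cross_entropy(y_wide, softmax(x @ weights + biases))

# Optimizer: plain gradient descent, stepping against the loss gradient.
for _ in range(100):
    probs = softmax(x @ weights + biases)
    grad_logits = (probs - y_wide) / len(x)
    weights -= learning_rate * (x.T @ grad_logits)
    biases -= learning_rate * grad_logits.sum(axis=0)

probs_after = softmax(x @ weights + biases)
loss_after = cross_entropy(y_wide, probs_after)
print(loss_before, loss_after, accuracy(labels, probs_after))
```

Keras optimizers apply the same idea (follow the gradient of the loss), typically with refinements such as mini-batches and adaptive step sizes.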
Since the model has been compiled, we are ready to train it. This network has an output layer consisting of 10 neurons (one for each class). Because of this, we’ll need a response of the same shape; that is, we need a binary matrix whose rows contain a 1 in exactly one position, denoting the class of the corresponding training record. We have this in y_train_wide.
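The binary matrix described here can also be built without an explicit loop. As a sketch, indexing the rows of an identity matrix by the label values gives the same one-hot layout; `one_hot` is a hypothetical helper name, and the three labels are made up for illustration.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Row i is all zeros except for a 1 in column labels[i]."""
    labels = np.asarray(labels)
    return np.eye(num_classes)[labels]

example_labels = [9, 0, 3]
example_wide = one_hot(example_labels)

print(example_wide)
```

With the notebook’s data, `one_hot(y_train)` should produce the same matrix as the loop that fills `y_train_wide`.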
## Epoch 1/5
## 1875/1875 [==============================] - 3s 2ms/step - loss: 0.5362 - accuracy: 0.8100
## Epoch 2/5
## 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4022 - accuracy: 0.8556
## Epoch 3/5
## 1875/1875 [==============================] - 3s 2ms/step - loss: 0.3668 - accuracy: 0.8654
## Epoch 4/5
## 1875/1875 [==============================] - 3s 1ms/step - loss: 0.3454 - accuracy: 0.8740
## Epoch 5/5
## 1875/1875 [==============================] - 3s 2ms/step - loss: 0.3326 - accuracy: 0.8769
## <keras.src.callbacks.History object at 0x000001A3DA576B30>
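Now let’s evaluate our model’s performance on the test set.
# One-hot encode the integer test labels to match the model's 10-unit output
y_test_wide = np.zeros((len(y_test), 10))
for i in range(len(y_test)):
    label = y_test[i]
    y_test_wide[i, label] = 1

# evaluate() returns [loss, accuracy] on the test data
score = model.evaluate(x_test, y_test_wide)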
## 313/313 [==============================] - 0s 1ms/step - loss: 0.3695 - accuracy: 0.8650
## [0.3695288896560669, 0.8650000095367432]
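The loop that builds `y_test_wide` can also be written without an explicit Python loop. As a minimal sketch, using a small made-up label vector (`y_demo`, a hypothetical stand-in for `y_test`), indexing the rows of an identity matrix produces the same one-hot matrix:

```python
import numpy as np

# Hypothetical stand-in for y_test: a few integer class labels in 0..9
y_demo = np.array([9, 2, 1, 1, 6])
num_classes = 10

# Loop version, mirroring the approach used in the notebook
wide_loop = np.zeros((len(y_demo), num_classes))
for i in range(len(y_demo)):
    wide_loop[i, y_demo[i]] = 1

# Vectorized version: row k of the identity matrix is the one-hot vector for class k
wide_eye = np.eye(num_classes)[y_demo]

# Both constructions should agree element by element
print(np.array_equal(wide_loop, wide_eye))
```

Either form works for `model.evaluate`; the vectorized one is simply shorter and avoids a Python-level loop over 10,000 labels.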
We got about 86.5% accuracy on the test set (the final training epoch reported about 87.7%) with a pretty vanilla and shallow neural network. There was only one hidden layer here, with 20% dropout. We didn’t tune any model hyperparameters and only trained over 5 epochs. The training loss was still decreasing and the training accuracy was still climbing from one epoch to the next, so training for more epochs might improve these numbers further.
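The accuracy reported by `model.evaluate` is the fraction of images whose most probable class matches the true label. The sketch below reproduces that calculation by hand on made-up probabilities for four images and three classes; the arrays `probs` and `true_labels` are illustrative only and are not taken from the Fashion MNIST model.

```python
import numpy as np

# Made-up class probabilities: 4 images, 3 classes (illustration only)
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
true_labels = np.array([0, 1, 2, 1])

# Most probable class for each image
predicted = probs.argmax(axis=1)

# Accuracy is the share of images where prediction and label agree
accuracy = float(np.mean(predicted == true_labels))
print(predicted, accuracy)
```

Here the last image is misclassified, so three of the four predictions are correct.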
Since our model has been trained, we can use it to make predictions again.
## 1/1 [==============================] - 0s 56ms/step
## <tf.Tensor: shape=(4, 10), dtype=float32, numpy=
## array([[0.08541211, 0.08541188, 0.2307762 , 0.08541188, 0.08544528,
## 0.08541188, 0.0858952 , 0.08541188, 0.08541188, 0.08541188],
## [0.08533675, 0.2319693 , 0.08533674, 0.08533675, 0.08533674,
## 0.08533674, 0.08533674, 0.08533674, 0.08533674, 0.08533674],
## [0.08533683, 0.23196799, 0.08533682, 0.08533739, 0.08533682,
## 0.08533682, 0.08533682, 0.08533682, 0.08533682, 0.08533682],
## [0.09722714, 0.08779798, 0.09837112, 0.08825789, 0.08933015,
## 0.08779713, 0.18763044, 0.08779696, 0.08799417, 0.087797 ]],
## dtype=float32)>
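The probabilities printed above are unusually flat for a model with roughly 86% test accuracy: the largest entry in each row is only about 0.23. The prediction code is not shown, so this is an assumption, but values like these are what results when softmax is applied a second time to probabilities that already came from a softmax output layer. A minimal check, using a made-up and fully confident probability vector:

```python
import math

def softmax(values):
    """Plain softmax over a list of numbers."""
    exps = [math.exp(v) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical softmax-layer output that is certain about class 1
confident = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# A second softmax squashes it: the top class becomes e / (e + 9)
twice = softmax(confident)
print(round(twice[1], 4), round(twice[0], 4))
```

The top entry comes out near 0.232 and the others near 0.0853, which is the pattern in the second row of the output above. If this is indeed the cause, calling `model.predict` without the extra softmax would return sharper probabilities. The predicted class is unaffected either way, because softmax preserves the ordering of its inputs.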
We can update our model so that it will provide class predictions rather than just the class membership probabilities.
class_model = Sequential([
    model,
    Lambda(lambda x: tf.argmax(x, axis = -1))
])
class_model.predict(x_test[1:5, :, :])
## 1/1 [==============================] - 0s 35ms/step
## array([2, 1, 1, 6], dtype=int64)
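The predicted classes are still numeric. As a small sketch, the item names from `labels_df` (copied here into a plain list so the example stands alone) can be used to translate them into readable labels:

```python
# Item names in label order, copied from labels_df earlier in the notebook
items = ["Tshirt", "Trousers", "Pullover", "Dress", "Coat",
         "Sandal", "Shirt", "Sneaker", "Bag", "AnkleBoot"]

# Class indices returned by class_model.predict above
predicted_classes = [2, 1, 1, 6]

# Look up the item name for each predicted class index
predicted_items = [items[k] for k in predicted_classes]
print(predicted_items)
```

So the four test images are predicted to be a pullover, two pairs of trousers, and a shirt.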
Now that we’ve trained and assessed one neural network, go back and
change your model. Add hidden layers to make it a true deep learning
model. Experiment with the dropout rate or activation functions. Just
remember that you’ll need 10 neurons in your output layer since we have
10 classes, and that the activation function used there should remain
softmax since we are working on a multiclass classification
problem. Everything else (other than the input shape) is fair game to
change though!
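As one possible starting point for this exercise, the sketch below adds a second hidden layer. The layer widths (256 and 128), the 30% dropout rate, and the name `deep_model` are illustrative choices rather than recommendations, and the optimizer and loss in `compile` are assumptions chosen to fit one-hot labels such as `y_test_wide`.

```python
from keras import Input
from keras.models import Sequential
from keras.layers import Flatten, Dense, Dropout

# One possible deeper variant; sizes and dropout rate are illustrative only
deep_model = Sequential([
    Input(shape=(28, 28)),            # input shape stays fixed at 28 x 28
    Flatten(),
    Dense(256, activation="relu"),
    Dropout(0.3),
    Dense(128, activation="relu"),
    Dropout(0.3),
    Dense(10, activation="softmax"),  # 10 classes, softmax output
])

# Assumed settings: categorical_crossentropy expects one-hot encoded labels
deep_model.compile(optimizer="adam",
                   loss="categorical_crossentropy",
                   metrics=["accuracy"])
deep_model.summary()
```

Training and evaluating this variant follows the same `fit` and `evaluate` steps as before; whether it actually beats the single-hidden-layer model is something to check on the test set rather than assume.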
In this notebook we used TensorFlow from R to build and assess a shallow neural network that classifies clothing items from very pixelated images. The images were \(28\times 28\) pixels. We saw that even a “simple” neural network, at about 86.5% test accuracy, could do well at predicting the class of an item based on its pixelated image, even for items that are hard for us humans to identify by eye.