Purpose: The previous two notebooks discussed theoretical and foundational aspects of deep learning models: in particular, what types of architectures and activation functions exist and, to a lesser extent, how to choose among them. In this notebook, our goal is to actually build, assess, and utilize a deep learning network for image classification.

Data and Modeling

Since you set up TensorFlow in an earlier notebook, let’s load {reticulate} in an R code chunk and then import tensorflow and keras in a Python chunk. We’ll use the Fashion MNIST data set, so we’ll load that as well. You can learn more about that data set from its official repository here.

library(reticulate)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.datasets import fashion_mnist
from keras.models import Sequential
from keras.layers import Flatten, Dense, Dropout, Lambda

#Load the train/test split that ships with Fashion MNIST
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

#Scale pixel intensities from integers in [0, 255] to floats in [0, 1]
x_train = x_train/255
x_test = x_test/255

#Lookup table mapping the numeric labels to item names
labels_df = pd.DataFrame({
  "label" : range(10),
  "item" : ["Tshirt", "Trousers", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "AnkleBoot"]
  })

In the code block above, we loaded the Fashion MNIST data, which comes already packaged into training and test sets. We then scaled the pixel intensities from integer values (between 0 and 255) to floats between 0 and 1. Finally, we created a data frame of labels for convenience, since the labels in y_train and y_test are numeric only.
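
As a quick, optional sanity check, we can confirm the array shapes and the rescaled value range:

print(x_train.shape, x_test.shape)  #expect (60000, 28, 28) (10000, 28, 28)
print(x_train.min(), x_train.max()) #expect 0.0 1.0 after scaling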

Let’s take a look at a few items and their labels.

item_num = 4

print("Item is: ", labels_df.loc[labels_df["label"] == y_train[item_num], "item"])
## Item is:  0    Tshirt
## Name: item, dtype: object
plt.imshow(x_train[item_num, :, :])
plt.show()
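
To view several items at once, here is a minimal sketch that plots the first nine training images in a grid along with their text labels:

#Plot the first nine training images with their item names as titles
fig, axes = plt.subplots(3, 3, figsize = (6, 6))
for ax, img, lab in zip(axes.flat, x_train[:9], y_train[:9]):
  ax.imshow(img)
  ax.set_title(labels_df["item"][lab])
  ax.axis("off")
plt.tight_layout()
plt.show()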

Okay – I’m having a difficult time identifying some of these items. Can we train a sequential neural network to learn the classes?

model = Sequential([
  Flatten(input_shape = (28,28)),
  Dense(128, activation = "relu"),
  Dropout(0.2),
  Dense(10, activation = "softmax")
])

model.summary()
## Model: "sequential"
## _________________________________________________________________
##  Layer (type)                Output Shape              Param #   
## =================================================================
##  flatten (Flatten)           (None, 784)               0         
##                                                                  
##  dense (Dense)               (None, 128)               100480    
##                                                                  
##  dropout (Dropout)           (None, 128)               0         
##                                                                  
##  dense_1 (Dense)             (None, 10)                1290      
##                                                                  
## =================================================================
## Total params: 101770 (397.54 KB)
## Trainable params: 101770 (397.54 KB)
## Non-trainable params: 0 (0.00 Byte)
## _________________________________________________________________
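
These parameter counts follow directly from the layer shapes: a dense layer with \(n\) inputs and \(m\) outputs has \(nm\) weights plus \(m\) biases. The hidden layer therefore contributes \(784 \times 128 + 128 = 100480\) parameters and the output layer \(128 \times 10 + 10 = 1290\), giving the \(101770\) total shown above. The Flatten and Dropout layers have no trainable parameters.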

We have a model with over 100,000 parameters! Because each of these is initialized with a random weight, we can use the model straight “out of the box” for prediction. We shouldn’t expect the untrained network to perform very well, though.

predictions = model.predict(x_train[:2, :, :])
## 1/1 [==============================] - 0s 115ms/step
#The output layer applies softmax, so these predictions are already
#class-membership probabilities (each row below sums to 1), not log-odds
predictions
## array([[0.10245589, 0.08768913, 0.12370513, 0.18182682, 0.08118267,
##         0.0402923 , 0.23050708, 0.05350573, 0.04104901, 0.05778628],
##        [0.0975403 , 0.11283207, 0.0658805 , 0.2278418 , 0.04386355,
##         0.05245278, 0.11948042, 0.16097698, 0.04762336, 0.07150826]],
##       dtype=float32)
#Applying softmax a second time is redundant (it would only be needed if the
#output layer returned raw logits); it flattens each row toward uniform
tf.nn.softmax(predictions)
## <tf.Tensor: shape=(2, 10), dtype=float32, numpy=
## array([[0.10006306, 0.0985963 , 0.10221207, 0.10832884, 0.09795687,
##         0.09403217, 0.11373278, 0.0952829 , 0.09410335, 0.09569164],
##        [0.09959807, 0.1011328 , 0.09649421, 0.1134593 , 0.09439291,
##         0.09520716, 0.10180741, 0.10612094, 0.09474847, 0.09703878]],
##       dtype=float32)>

Let’s define a loss function so that we can train the model by optimizing the loss. First, we’ll one-hot encode the training labels, since the loss must compare the network’s 10 output probabilities against a target of the same shape.

#One-hot encode the labels: row i gets a 1 in column y_train[i]
y_train_wide = np.zeros((len(y_train), 10))

for i in range(len(y_train)):
  label = y_train[i]
  y_train_wide[i, label] = 1
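
As an aside, Keras provides a helper that builds the same one-hot matrix in a single call; a minimal equivalent sketch:

#Equivalent one-hot encoding via the built-in Keras utility
y_train_wide_alt = tf.keras.utils.to_categorical(y_train, num_classes = 10)
assert np.array_equal(y_train_wide, y_train_wide_alt)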

predictions
## array([[0.10245589, 0.08768913, 0.12370513, 0.18182682, 0.08118267,
##         0.0402923 , 0.23050708, 0.05350573, 0.04104901, 0.05778628],
##        [0.0975403 , 0.11283207, 0.0658805 , 0.2278418 , 0.04386355,
##         0.05245278, 0.11948042, 0.16097698, 0.04762336, 0.07150826]],
##       dtype=float32)
y_train_wide[:2, :]
## array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
##        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
loss_fn = tf.keras.losses.CategoricalCrossentropy()
loss_fn(y_train_wide[:2], predictions)
## <tf.Tensor: shape=(), dtype=float32, numpy=2.5892467>

Before training, we’ll need to set the optimizer, assign the loss function, and define the performance metric. We’ll then compile the model with these attributes.

model.compile(optimizer = "adam", loss = loss_fn, metrics = ["accuracy"])

Now that the model has been compiled, we are ready to train it. This network has an output layer consisting of 10 neurons (one for each class). Because of this, we’ll need a response of the same shape: a binary matrix whose rows contain a 1 in exactly one position, denoting the class of the corresponding training record. We have this in y_train_wide.
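
As an alternative, Keras also offers a sparse version of this loss that accepts the integer labels directly, which would let us skip the one-hot encoding entirely; a sketch (we keep the one-hot approach here):

#Sparse alternative: this loss accepts integer class labels directly,
#so the one-hot matrix would not be needed
sparse_loss = tf.keras.losses.SparseCategoricalCrossentropy()
#model.compile(optimizer = "adam", loss = sparse_loss, metrics = ["accuracy"])
#model.fit(x_train, y_train, epochs = 5)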

model.fit(x_train, y_train_wide, epochs = 5)
## Epoch 1/5
## 1875/1875 [==============================] - 3s 2ms/step - loss: 0.5362 - accuracy: 0.8100
## Epoch 2/5
## 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4022 - accuracy: 0.8556
## Epoch 3/5
## 1875/1875 [==============================] - 3s 2ms/step - loss: 0.3668 - accuracy: 0.8654
## Epoch 4/5
## 1875/1875 [==============================] - 3s 1ms/step - loss: 0.3454 - accuracy: 0.8740
## Epoch 5/5
## 1875/1875 [==============================] - 3s 2ms/step - loss: 0.3326 - accuracy: 0.8769
## <keras.src.callbacks.History object at 0x000001A3DA576B30>

Now let’s evaluate the model’s performance on the test set. We’ll need to one-hot encode the test labels in the same way.

y_test_wide = np.zeros((len(y_test), 10))

for i in range(len(y_test)):
  label = y_test[i]
  y_test_wide[i, label] = 1

score = model.evaluate(x_test, y_test_wide)
## 313/313 [==============================] - 0s 1ms/step - loss: 0.3695 - accuracy: 0.8650
score
## [0.3695288896560669, 0.8650000095367432]

We got about 86.5% accuracy on the test data with a pretty vanilla, shallow neural network. There was only one hidden layer here, with 20% dropout. We didn’t tune any model hyperparameters and trained for only 5 epochs. The epoch summaries above show the loss still decreasing and the accuracy still climbing from one epoch to the next, so additional training would likely help.
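
If we want more detail than overall accuracy, we can look at which classes get mistaken for which. Here is a minimal sketch using tf.math.confusion_matrix (rows are true classes, columns are predicted classes):

#Per-class detail: entry (i, j) counts true class i predicted as class j
probs = model.predict(x_test)
pred_classes = np.argmax(probs, axis = 1)
print(tf.math.confusion_matrix(y_test, pred_classes))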

Now that our model has been trained, we can use it to make predictions again.

predictions = model.predict(x_test[1:5, :, :])
## 1/1 [==============================] - 0s 56ms/step
#Redundant again: softmax is being applied to probabilities, not logits
tf.nn.softmax(predictions)
## <tf.Tensor: shape=(4, 10), dtype=float32, numpy=
## array([[0.08541211, 0.08541188, 0.2307762 , 0.08541188, 0.08544528,
##         0.08541188, 0.0858952 , 0.08541188, 0.08541188, 0.08541188],
##        [0.08533675, 0.2319693 , 0.08533674, 0.08533675, 0.08533674,
##         0.08533674, 0.08533674, 0.08533674, 0.08533674, 0.08533674],
##        [0.08533683, 0.23196799, 0.08533682, 0.08533739, 0.08533682,
##         0.08533682, 0.08533682, 0.08533682, 0.08533682, 0.08533682],
##        [0.09722714, 0.08779798, 0.09837112, 0.08825789, 0.08933015,
##         0.08779713, 0.18763044, 0.08779696, 0.08799417, 0.087797  ]],
##       dtype=float32)>

Notice that each row above is nearly uniform: since the model already outputs probabilities, the extra softmax squashed them together, and printing predictions directly would have shown the actual class-membership probabilities. We can also wrap our model so that it provides class predictions rather than class-membership probabilities.

class_model = Sequential([
  model,
  Lambda(lambda x: tf.argmax(x, axis = -1)) #index of the highest-probability class
])

class_model.predict(x_test[1:5, :, :])
## 1/1 [==============================] - 0s 35ms/step
## array([2, 1, 1, 6], dtype=int64)

Now that we’ve trained and assessed one neural network, go back and change your model. Add hidden layers to make it a true deep learning model. Experiment with the dropout rate or activation functions. Just remember that you’ll need 10 neurons in your output layer, since we have 10 classes, and that the activation function used there should remain softmax, since we are working on a multiclass classification problem. Everything else (other than the input shape) is fair game to change, though! A sketch of one possibility follows.
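
Here is one deeper variant to start from; the layer widths and dropout rates below are arbitrary choices, not tuned values:

#One possible deeper variant to experiment with (untuned)
deep_model = Sequential([
  Flatten(input_shape = (28,28)),
  Dense(256, activation = "relu"),
  Dropout(0.3),
  Dense(128, activation = "relu"),
  Dropout(0.3),
  Dense(10, activation = "softmax") #10 classes, so keep 10 softmax outputs
])
deep_model.compile(optimizer = "adam", loss = loss_fn, metrics = ["accuracy"])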

Summary

In this notebook we used TensorFlow from R (via {reticulate}) to build and assess a shallow neural network that classifies clothing items from low-resolution, \(28\times 28\)-pixel images. We saw that even a “simple” neural network was much better at predicting the class of an item from its pixelated image than we are as humans.