Transfer Learning

August 15, 2022

![](https://d3i8kmg9ozyieo.cloudfront.net/transfer-learning-0.png)
Photo by Stephen Dawson on Unsplash

Long gone are the days in which data practitioners trained machine learning models from scratch themselves. Unless you have a very specific use case, you’re better off leveraging the pre-trained models available in public repositories such as TensorFlow Hub and Hugging Face. The process of customizing a pre-trained model for a given task is known as transfer learning.

There are two techniques under the umbrella of transfer learning:

  • Feature extraction: Uses the output of a pre-trained model, taken before its final classification layer (i.e. before any sigmoid or softmax is applied), as the input features for a new classifier. When training the new model, you keep the weights of the pre-trained model fixed.
  • Fine-tuning: Unfreezes the weights in the layers of the base model and jointly trains them along with the ones in the newly-added classifier layers.
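To make the idea of feature extraction concrete, here is a minimal pure-Python sketch. It is not the article's TensorFlow code: the "base model" is a hypothetical fixed linear map with a ReLU, the data is a made-up toy problem, and the head is a simple logistic regression. The point is only that the base weights never change while the head is trained.

```python
import math

# Hypothetical stand-in for a pre-trained base model: a fixed linear map
# from 2 inputs to 3 features followed by a ReLU. Its weights are frozen.
BASE_WEIGHTS = [[0.5, -1.0], [1.5, 0.3], [-0.7, 0.8]]


def extract_features(x):
    """Run the frozen base model; nothing in here is ever updated."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in BASE_WEIGHTS]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_head(samples, labels, epochs=200, lr=0.5):
    """Fit a logistic-regression head on top of the frozen features."""
    head_w = [0.0, 0.0, 0.0]
    head_b = 0.0
    for _ in range(epochs):
        for x, y in zip(samples, labels):
            feats = extract_features(x)
            pred = sigmoid(sum(w * f for w, f in zip(head_w, feats)) + head_b)
            error = pred - y  # gradient of the log loss w.r.t. the logit
            head_w = [w - lr * error * f for w, f in zip(head_w, feats)]
            head_b -= lr * error
    return head_w, head_b


def predict(x, head_w, head_b):
    feats = extract_features(x)
    return sigmoid(sum(w * f for w, f in zip(head_w, feats)) + head_b)


# Toy data: the label is 1 when the first coordinate dominates.
samples = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.2, 0.8]]
labels = [1, 1, 0, 0]

frozen_before = [row[:] for row in BASE_WEIGHTS]
head_w, head_b = train_head(samples, labels)
print("Base weights unchanged:", BASE_WEIGHTS == frozen_before)
```

Fine-tuning would differ in exactly one respect: the update step would also adjust `BASE_WEIGHTS`, usually with a small learning rate.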

In this article, we will walk through an example of feature extraction. If you’d like to learn more about fine-tuning, you can check out this article on the subject.

Python

We’ll be using code taken from the “TF Hub for TF2: Retraining an Image Classifier” tutorial, since I couldn’t make it any clearer myself.

I recommend using Google Colab to train and run the model since it already has all the dependencies installed. You can view what GPU your runtime has access to as follows:

! nvidia-smi
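Outside a notebook, where the `!` shell escape is not available, roughly the same check can be sketched with the standard library. The helper name `gpu_summary` is made up for this illustration; it simply runs `nvidia-smi` if that tool is on the PATH.

```python
import shutil
import subprocess


def gpu_summary():
    """Return nvidia-smi's output, or None when the tool is missing or fails."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    return result.stdout if result.returncode == 0 else None


summary = gpu_summary()
print(summary if summary else "No NVIDIA GPU visible to this runtime.")
```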

We begin by importing the following libraries:

import tensorflow as tf  
import tensorflow_hub as hub
import matplotlib.pylab as plt  
import numpy as np

We will use the MobileNetV2 architecture trained on the ImageNet dataset as the base model.

model_handle = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4"

In TensorFlow Hub, you can download classification and feature vector models. The feature vector models are specifically designed for transfer learning. As the name implies, their output is a feature vector rather than class predictions. Thus, you don’t need to remove the output layer of the model when doing transfer learning. You can simply add additional trainable layers.

![](https://d3i8kmg9ozyieo.cloudfront.net/transfer-learning-1.png)

The model was trained using images of size 224 by 224 pixels. We set a sensible default value for the batch size.

IMAGE_SIZE = (224, 224)  
BATCH_SIZE = 16

The Keras library provides a utility function, get_file, that downloads a file from a URL and caches it locally. Here, the flower photos archive is hosted on Google Cloud Storage.

data_dir = tf.keras.utils.get_file(  
    'flower_photos',  
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

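Using the images we just downloaded, we construct the training and validation datasets. The code can also perform data augmentation by transforming the training images slightly (i.e. rotating, translating, zooming, flipping), although this is disabled by default here (do_data_augmentation = False).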

def build_dataset(subset):  
  return tf.keras.preprocessing.image_dataset_from_directory(  
      data_dir,  
      validation_split=.20,  
      subset=subset,  
      label_mode="categorical",  
      # A seed needs to be provided when using validation_split and shuffle = True.
      # A fixed seed is used so that the validation set is stable across runs.  
      seed=123,  
      image_size=IMAGE_SIZE,  
      batch_size=1)
train_ds = build_dataset("training")  
class_names = tuple(train_ds.class_names)  
train_size = train_ds.cardinality().numpy()  
train_ds = train_ds.unbatch().batch(BATCH_SIZE)  
train_ds = train_ds.repeat()
normalization_layer = tf.keras.layers.Rescaling(1. / 255)  
preprocessing_model = tf.keras.Sequential([normalization_layer])  
do_data_augmentation = False #@param {type:"boolean"}  
if do_data_augmentation:  
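  # RandomRotation takes a fraction of a full turn (2*pi), not degrees;
  # 40 / 360 gives rotations of up to roughly +/-40 degrees.
  preprocessing_model.add(
      tf.keras.layers.RandomRotation(40 / 360))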
  preprocessing_model.add(  
      tf.keras.layers.RandomTranslation(0, 0.2))  
  preprocessing_model.add(  
      tf.keras.layers.RandomTranslation(0.2, 0))  
  # Like the old tf.keras.preprocessing.image.ImageDataGenerator(),  
  # image sizes are fixed when reading, and then a random zoom is applied.  
  # If all training inputs are larger than image_size, one could also use  
  # RandomCrop with a batch size of 1 and rebatch later.  
  preprocessing_model.add(  
      tf.keras.layers.RandomZoom(0.2, 0.2))  
  preprocessing_model.add(  
      tf.keras.layers.RandomFlip(mode="horizontal"))  
train_ds = train_ds.map(lambda images, labels:  
                        (preprocessing_model(images), labels))
val_ds = build_dataset("validation")  
valid_size = val_ds.cardinality().numpy()  
val_ds = val_ds.unbatch().batch(BATCH_SIZE)  
val_ds = val_ds.map(lambda images, labels:  
                    (normalization_layer(images), labels))
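The comment about the fixed seed is worth a small illustration. The sketch below is not how Keras implements `validation_split` internally; `split_indices` is a hypothetical helper that only shows why a seeded shuffle yields the same, non-overlapping training and validation sets on every run.

```python
import random


def split_indices(num_samples, validation_split, seed):
    """Shuffle sample indices with a fixed seed, then carve off a validation slice."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # same seed -> same order every run
    num_val = round(num_samples * validation_split)
    return indices[num_val:], indices[:num_val]


train_a, val_a = split_indices(100, 0.20, seed=123)
train_b, val_b = split_indices(100, 0.20, seed=123)

print(len(train_a), len(val_a))          # sizes of the two subsets
print(val_a == val_b)                    # identical across "runs"
print(set(train_a).isdisjoint(val_a))    # no sample appears in both
```

Without a fixed seed, the "training" and "validation" calls could shuffle differently, and some images might end up in both subsets.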

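We freeze the weights of the MobileNetV2 model by leaving fine-tuning disabled. This flag is passed as the trainable argument of hub.KerasLayer when we build the model.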

do_fine_tuning = False

We add a Dropout layer (to prevent overfitting) and a Dense layer (for classification) to the output of our model.

model = tf.keras.Sequential([  
    # Explicitly define the input shape so the model can be properly  
    # loaded by the TFLiteConverter  
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),  
    hub.KerasLayer(model_handle, trainable=do_fine_tuning),  
    tf.keras.layers.Dropout(rate=0.2),  
    tf.keras.layers.Dense(len(class_names),  
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))  
])  
model.build((None,)+IMAGE_SIZE+(3,))  
model.summary()

As the summary shows, only 6,405 of the 2,264,389 parameters are trainable, which speeds up training significantly.

Model: "sequential_1"  
_________________________________________________________________  
 Layer (type)                Output Shape              Param #     
=================================================================  
 keras_layer (KerasLayer)    (None, 1280)              2257984     
                                                                   
 dropout (Dropout)           (None, 1280)              0           
                                                                   
 dense (Dense)               (None, 5)                 6405        
                                                                   
=================================================================  
Total params: 2,264,389  
Trainable params: 6,405  
Non-trainable params: 2,257,984  
_________________________________________________________________
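The trainable parameter count can be checked by hand. The short calculation below uses only the numbers printed in the summary: a Dense layer mapping 1,280 features to 5 classes has one weight per input/output pair plus one bias per output.

```python
feature_dim = 1280       # output size of the MobileNetV2 feature vector
num_classes = 5          # units in the Dense layer
base_params = 2_257_984  # frozen parameters reported for keras_layer

dense_params = feature_dim * num_classes + num_classes  # weights + biases
total_params = base_params + dense_params
trainable_fraction = dense_params / total_params

print(dense_params)                  # 6405
print(total_params)                  # 2264389
print(f"{trainable_fraction:.2%}")   # 0.28%
```

In other words, well under one percent of the network is updated during training.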

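We compile the model using SGD with a learning rate of 0.005 and a momentum of 0.9, and categorical cross-entropy for the loss function. Because the Dense layer has no activation, the model outputs raw logits, hence from_logits=True. We also apply label smoothing of 0.1.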

model.compile(  
  optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),   
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),  
  metrics=['accuracy'])
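To see what `from_logits=True` and `label_smoothing=0.1` mean numerically, here is a small pure-Python sketch of the loss for a single example. It follows the documented Keras smoothing rule (each target becomes `y * (1 - smoothing) + smoothing / num_classes`), but the function names are made up for this illustration and the code is not the Keras implementation itself.

```python
import math


def smooth_labels(one_hot, label_smoothing):
    """Move a little probability mass from the true class to all classes."""
    k = len(one_hot)
    return [y * (1.0 - label_smoothing) + label_smoothing / k for y in one_hot]


def softmax(logits):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy_from_logits(one_hot, logits, label_smoothing=0.0):
    targets = smooth_labels(one_hot, label_smoothing)
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(targets, probs))


y_true = [0, 0, 0, 1, 0]                # one-hot label for the fourth class
logits = [-1.0, 0.5, 0.0, 3.0, 0.2]     # made-up raw model outputs

print(smooth_labels(y_true, 0.1))
print(categorical_crossentropy_from_logits(y_true, logits, label_smoothing=0.1))
```

With five classes and a smoothing value of 0.1, the true class gets a target of 0.92 and every other class 0.02, which discourages the model from becoming overconfident.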

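Because the training dataset repeats indefinitely, we have to tell Keras how many batches make up one epoch. We define the number of steps per epoch.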

steps_per_epoch = train_size // BATCH_SIZE  
validation_steps = valid_size // BATCH_SIZE
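As a quick worked example of the integer division above, assume (purely for illustration, these counts are not stated in the article) that the 80/20 split produced 2,936 training images and 734 validation images:

```python
BATCH_SIZE = 16
train_size = 2936  # assumed number of training images
valid_size = 734   # assumed number of validation images

steps_per_epoch = train_size // BATCH_SIZE    # full batches per epoch
validation_steps = valid_size // BATCH_SIZE
leftover_train = train_size % BATCH_SIZE      # images that do not fill a batch

print(steps_per_epoch)    # 183
print(validation_steps)   # 45
print(leftover_train)     # 8
```

Floor division drops the final partial batch, so a few images are skipped in each pass. For the training set this matters little, since the dataset repeats and later epochs pick up where the previous one stopped.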

Finally, we train the model.

hist = model.fit(  
    train_ds,  
    epochs=5, steps_per_epoch=steps_per_epoch,  
    validation_data=val_ds,  
    validation_steps=validation_steps).history

We plot the training and validation loss and accuracy over the course of training.

# The history holds one value per epoch, so the x-axis is in epochs.
plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Epochs")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Epochs")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])

![](https://d3i8kmg9ozyieo.cloudfront.net/transfer-learning-2.png)
![](https://d3i8kmg9ozyieo.cloudfront.net/transfer-learning-3.png)

As the plots show, it didn’t take long for the accuracy of the model to start hovering around 90%.

We can use the model to infer the class of a specific sample in the dataset.

x, y = next(iter(val_ds))  
image = x[0, :, :, :]  
true_index = np.argmax(y[0])  
plt.imshow(image)  
plt.axis('off')  
plt.show()
prediction_scores = model.predict(np.expand_dims(image, axis=0))  
predicted_index = np.argmax(prediction_scores)  
print("True label: " + class_names[true_index])  
print("Predicted label: " + class_names[predicted_index])

![](https://d3i8kmg9ozyieo.cloudfront.net/transfer-learning-4.png)

```
True label: sunflowers
Predicted label: sunflowers
```

As we can see, the model classified the image correctly.
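Since the Dense layer has no activation, the prediction scores are raw logits rather than probabilities. The sketch below shows, with made-up scores and an assumed class ordering (neither comes from the article's run), how a softmax turns them into probabilities and why taking the argmax of the logits gives the same predicted class.

```python
import math

# Assumed class order and made-up logits, for illustration only.
class_names = ("daisy", "dandelion", "roses", "sunflowers", "tulips")
logits = [-1.2, 0.3, -0.5, 4.1, 0.8]


def softmax(scores):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)
predicted_index = max(range(len(logits)), key=lambda i: logits[i])
most_probable_index = max(range(len(probs)), key=lambda i: probs[i])

print("Predicted label: " + class_names[predicted_index])
print("Same class from probabilities:", predicted_index == most_probable_index)
```

Softmax preserves the ordering of the scores, so it is only needed when you want to report a confidence value alongside the label.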


Written by Cory Maklin