Transfer Learning
Long gone are the days when data practitioners trained machine learning models from scratch. Unless you have a very specific use case, you’re better off leveraging the pre-trained models made available in public repositories such as TensorFlow Hub and Hugging Face. The process by which we customize a pre-trained model for a given task is known as transfer learning.
There are two techniques under the umbrella of transfer learning:
- Feature extraction: Uses the output of a pre-trained model (taken prior to the final classification layer) as the features for a new classifier. When training the new model, you keep the weights of the pre-trained model fixed.
- Fine-tuning: Unfreezes the weights in the layers of the base model and jointly trains them along with the ones in the newly added classifier layers. The difference comes down to a single flag, as the sketch after this list shows.
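Here is a minimal sketch of that difference, using tf.keras.applications rather than TF Hub (the five-class Dense head is illustrative):
import tensorflow as tf

# Illustrative base model; the walkthrough below uses a TF Hub module instead.
base = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, pooling="avg")

# Feature extraction: freeze the base so only the new head is trained.
base.trainable = False
# Fine-tuning would instead set base.trainable = True (usually with a
# much lower learning rate) so the base weights are updated as well.

model = tf.keras.Sequential([base, tf.keras.layers.Dense(5)])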
In this article, we will walk through an example of feature extraction. If you’d like to learn more about fine-tuning, you can check out this article on the subject.
Python
We’ll be using the code taken from the TF Hub for TF2: Retraining an image classifier tutorial since I couldn’t make it any clearer myself.
I recommend using Google Colab to train and run the model, since it already has all the dependencies installed. You can check which GPU your runtime has access to as follows:
!nvidia-smi
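Once TensorFlow is imported, you can also confirm that it detects the GPU:
import tensorflow as tf
# Prints e.g. [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
print(tf.config.list_physical_devices('GPU'))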
We begin by importing the following libraries:
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
We will use the MobileNetV2 architecture trained on the ImageNet dataset as the base model.
model_handle = "[https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4](https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4)"
In TensorFlow Hub, you can download both classification and feature vector models. The feature vector models are specifically designed for transfer learning: as the name implies, their output is a feature vector (the pooled activations prior to the final classification layer). Thus, you don’t need to remove the output layer of the model when doing transfer learning; you can simply add additional trainable layers.
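For instance, you can quickly confirm the shape of the features produced by the module above (a small sketch; the 1280-dimensional output is specific to this MobileNetV2 module):
feature_extractor = hub.KerasLayer(model_handle, trainable=False)
features = feature_extractor(tf.zeros((1, 224, 224, 3)))
print(features.shape)  # (1, 1280): one feature vector per input image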
The model was trained using images of size 224 by 224 pixels. We set a sensible default value for the batch size.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16
The Keras library provides a utility function for downloading (and caching) files from a URL, in this case a Google Cloud Storage bucket.
data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
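To sanity-check the download, you can list the class folders and count the images on disk (a quick sketch; the flower_photos archive contains five class folders with 3,670 images in total):
import pathlib
data_root = pathlib.Path(data_dir)
print(sorted(item.name for item in data_root.iterdir() if item.is_dir()))
print(sum(1 for _ in data_root.glob('*/*.jpg')))  # number of images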
Using the images we just downloaded, we construct the training and validation datasets. Optionally (via the do_data_augmentation flag below), we can augment the data by transforming the images slightly (i.e. rotating, translating, flipping, and zooming).
def build_dataset(subset):
    return tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=.20,
        subset=subset,
        label_mode="categorical",
        # A seed needs to be provided when using validation_split with
        # shuffle=True. A fixed seed keeps the validation set stable across runs.
        seed=123,
        image_size=IMAGE_SIZE,
        batch_size=1)
train_ds = build_dataset("training")
class_names = tuple(train_ds.class_names)
train_size = train_ds.cardinality().numpy()
train_ds = train_ds.unbatch().batch(BATCH_SIZE)
train_ds = train_ds.repeat()
normalization_layer = tf.keras.layers.Rescaling(1. / 255)
preprocessing_model = tf.keras.Sequential([normalization_layer])
do_data_augmentation = False  #@param {type:"boolean"}
if do_data_augmentation:
    preprocessing_model.add(tf.keras.layers.RandomRotation(40))
    preprocessing_model.add(tf.keras.layers.RandomTranslation(0, 0.2))
    preprocessing_model.add(tf.keras.layers.RandomTranslation(0.2, 0))
    # Like the old tf.keras.preprocessing.image.ImageDataGenerator(),
    # image sizes are fixed when reading, and then a random zoom is applied.
    # If all training inputs are larger than image_size, one could also use
    # RandomCrop with a batch size of 1 and rebatch later.
    preprocessing_model.add(tf.keras.layers.RandomZoom(0.2, 0.2))
    preprocessing_model.add(tf.keras.layers.RandomFlip(mode="horizontal"))
train_ds = train_ds.map(lambda images, labels:
                        (preprocessing_model(images), labels))
val_ds = build_dataset("validation")
valid_size = val_ds.cardinality().numpy()
val_ds = val_ds.unbatch().batch(BATCH_SIZE)
val_ds = val_ds.map(lambda images, labels:
                    (normalization_layer(images), labels))
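At this point it’s worth sanity-checking the input pipeline by pulling a single batch (the shapes below assume the settings used above):
images, labels = next(iter(train_ds))
print(images.shape)  # (16, 224, 224, 3): a batch of rescaled RGB images
print(labels.shape)  # (16, 5): one-hot labels, one flower class per column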
We freeze the weights of the MobileNetV2 model by leaving do_fine_tuning set to False; the flag is passed to hub.KerasLayer below.
do_fine_tuning = False
We add a Dropout layer (to prevent overfitting) and a Dense layer (for classification) to the output of our model.
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter.
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(model_handle, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(len(class_names),
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,) + IMAGE_SIZE + (3,))
model.summary()
As we can see, only 6,405 of the 2,264,389 parameters are trainable: the Dense layer’s 1280 × 5 = 6,400 weights plus 5 biases. This will speed up training significantly.
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
keras_layer (KerasLayer) (None, 1280) 2257984
dropout (Dropout) (None, 1280) 0
dense (Dense) (None, 5) 6405
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
We compile the model using SGD with momentum, a learning rate of 0.005, and categorical cross-entropy (with label smoothing) as the loss function. Since our Dense layer outputs raw logits rather than probabilities, we set from_logits=True.
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    loss=tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, label_smoothing=0.1),
    metrics=['accuracy'])
We derive the number of training and validation steps per epoch from the dataset sizes.
steps_per_epoch = train_size // BATCH_SIZE
validation_steps = valid_size // BATCH_SIZE
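With the flowers dataset (3,670 images and a 20% validation split), this works out to roughly 2936 // 16 = 183 training steps and 734 // 16 = 45 validation steps per epoch.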
Finally, we train the model.
hist = model.fit(
    train_ds,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps).history
We plot the training and validation loss over time.
plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])
plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
As we can see, it didn’t take long for the accuracy of the model to start hovering around 90%.
We can use the model to infer the class of a specific sample from the validation set.
x, y = next(iter(val_ds))
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + class_names[true_index])
print("Predicted label: " + class_names[predicted_index])
As we can see, it accurately classified the image.
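Since the Dense layer has no activation, the prediction scores are raw logits; if you want per-class probabilities, you can pipe them through a softmax (a small sketch):
probabilities = tf.nn.softmax(prediction_scores[0]).numpy()
for name, p in zip(class_names, probabilities):
    print(f"{name}: {p:.3f}")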