This post will cover class model visualization, which is described in Section 2 of this paper. Class model visualization is a technique for using a trained classification CNN to create an image that is representative of a particular class for that CNN. A class model for “bird” maximally activates the CNN output neuron corresponding to the “bird” class.
Here is Figure 1 from the paper, showing example class models for twelve different classes:
The paper by Simonyan et al. that introduces class models also introduces backpropagation for creating saliency maps. Saliency maps are discussed separately, in this post.
The Expression for a Class Model
A class model is an image I that produces a high score when fed into an already-trained CNN. A class model I is summarized by the following expression:

$$\arg\max_I \; S_c(I) - \lambda \lVert I \rVert_2^2$$

Here, we are doing “arg max” over images I, which means we are trying to find a particular image I (the class model) such that the second piece of the expression, $S_c(I) - \lambda \lVert I \rVert_2^2$, is maximized.
This second piece of the expression is simply the CNN’s raw class score for that image, $S_c(I)$, minus an L2 regularization term $\lambda \lVert I \rVert_2^2$ whose strength is set by the hyperparameter λ (review of regularization here).
In other words, the expression is another way of saying that a class model is an image I that maximizes a class score.
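To make the expression concrete, here is a minimal NumPy sketch of the quantity being maximized. It assumes a hypothetical `scores_fn` that maps an image to a vector of raw class scores. A tiny linear classifier stands in for the trained CNN, and its weights, the example image, and the value of `lam` (λ) are made up purely for illustration.

```python
import numpy as np

def class_model_objective(scores_fn, image, c, lam):
    """Return S_c(I) - lam * ||I||_2^2, the quantity a class model maximizes."""
    raw_score = scores_fn(image)[c]          # raw (pre-softmax) score for class c
    l2_penalty = lam * np.sum(image ** 2)    # lam times the squared L2 norm of the pixels
    return raw_score - l2_penalty

# Hypothetical stand-in for a trained CNN: a linear classifier over a
# flattened 4-pixel "image" with 3 classes.
W = np.array([[1.0, 0.0, 2.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [-1.0, 0.5, 0.0, 0.0]])

def toy_scores(image):
    return W @ image.ravel()

image = np.array([0.5, -0.5, 1.0, 0.0])
print(class_model_objective(toy_scores, image, c=0, lam=0.1))
```

Setting `lam=0.0` recovers the raw score alone; larger values trade some score for smaller pixel values.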
Why include the L2 regularization term?
The L2 regularization term encourages the learned image not to contain extreme values. Recall that L2 regularization tries to make all values small in magnitude (but not necessarily zero), while L1 regularization tries to make as many values zero as possible. It’s possible to create class models using different kinds of regularization.
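One way to see why the two penalties behave differently is to compare their gradients, since the gradient is what pushes on each pixel during optimization. The NumPy sketch below uses an arbitrary example vector. The L2 gradient is proportional to each value, so large values are shrunk hardest, while the L1 subgradient has the same magnitude for every nonzero value, so it keeps pushing small values toward exactly zero.

```python
import numpy as np

def l2_penalty_and_grad(x):
    # ||x||_2^2 and its gradient 2x: the push is proportional to each value.
    return np.sum(x ** 2), 2 * x

def l1_penalty_and_grad(x):
    # ||x||_1 and a subgradient sign(x): the same-sized push for every nonzero value.
    return np.sum(np.abs(x)), np.sign(x)

x = np.array([3.0, 0.5, -0.1])
l2, g2 = l2_penalty_and_grad(x)
l1, g1 = l1_penalty_and_grad(x)
print("L2 penalty and gradient:", l2, g2)
print("L1 penalty and gradient:", l1, g1)
```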
The repository pytorch-cnn-visualizations provides the following example of the effect regularization has on the appearance of the class model:
First, here is a gif showing the process of learning a class model for the “flamingo” class without any regularization at all:
We can see that the resulting image includes many bright colors (high values).
Next, here’s a gif showing the process of learning a “flamingo” class model with L2 regularization:
This one includes more greyish values, and fewer high values, as we would expect from applying L2 regularization.
Finally, here is a gif showing the process of learning a “flamingo” class model with L1 regularization:
We can see that this class model includes many black (zero) values, as we would expect from applying L1 regularization.
There is no “one right way” to regularize a class model. The whole point of class models is to provide a curious human with insight into what a CNN perceives about a class. L2 regularization is a good option, but you can explore other forms of regularization too.
Why Maximize Score Rather than Probability?
Recall that a classification CNN produces raw class scores Sc, which are then converted into probabilities Pc using a softmax layer:

$$P_c = \frac{\exp(S_c)}{\sum_{c'} \exp(S_{c'})}$$
In the class model approach, we try to maximize the unnormalized class score Sc, rather than the class probability Pc. This is because there are two ways to maximize the class probability Pc:
1. We can maximize the raw score for the class of interest
2. We can minimize the raw scores for the other classes
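We don’t want to end up doing (2), because then the image could reach a high probability mainly by suppressing the other classes rather than by resembling the class of interest, and it won’t be clear how to interpret the resulting image. Therefore, we ignore the softmax layer and directly maximize the raw score Sc for the class of interest.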
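A small NumPy sketch illustrates the problem with made-up scores for three classes. Lowering the scores of the other classes raises Pc even though Sc never changes, which is exactly the route to a high probability that the class model approach avoids by maximizing Sc directly.

```python
import numpy as np

def softmax(scores):
    # Subtracting the max improves numerical stability without changing the result.
    shifted = scores - np.max(scores)
    exp = np.exp(shifted)
    return exp / np.sum(exp)

c = 0                                        # class of interest
scores = np.array([2.0, 1.0, 1.0])           # made-up raw scores S
suppressed = np.array([2.0, -3.0, -3.0])     # same S_c, lower scores for the other classes

p_before = softmax(scores)[c]
p_after = softmax(suppressed)[c]
print(f"S_c: {scores[c]} -> {suppressed[c]}")
print(f"P_c: {p_before:.3f} -> {p_after:.3f}")
```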
Steps for Creating a Class Model
We can learn an image I to produce a high score for the class “bird” through the following steps:
1. Train a classification CNN.
2. Create a random zero-centered image (assuming the classification CNN was trained on zero-centered image data).
3. Repeat the following steps: (a) do a forward pass on the image to compute the current class scores; (b) use the backpropagation algorithm to find the gradient of the “bird” neuron output (score) with respect to the image pixels; (c) make a small update to the image so that it will produce an even higher “bird” score on the next forward pass.
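The steps above can be sketched in NumPy as follows. This is not the paper's code: a fixed linear classifier is a hypothetical stand-in for the trained CNN, so the gradient of the score with respect to the image can be written by hand. With a real CNN, step (b) would use backpropagation, for example through a deep learning framework's autograd. The names `make_class_model`, `lam`, `lr`, and `num_steps` and their values are illustrative, and the objective includes the L2 term described earlier.

```python
import numpy as np

# Hypothetical stand-in for a trained CNN (step 1): a linear classifier whose
# raw class scores are S = W @ I + b for a flattened image I.
rng = np.random.default_rng(0)
NUM_CLASSES, NUM_PIXELS = 3, 8
W = rng.normal(size=(NUM_CLASSES, NUM_PIXELS))
b = rng.normal(size=NUM_CLASSES)

def class_scores(image):
    return W @ image + b

def score_grad(image, c):
    # dS_c/dI. For this linear stand-in it is just row c of W;
    # for a real CNN this gradient would come from backpropagation.
    return W[c]

def make_class_model(c, lam=0.5, lr=0.1, num_steps=200):
    # Step 2: start from a small random, zero-centered image.
    image = rng.normal(scale=0.01, size=NUM_PIXELS)
    score_history = []
    for _ in range(num_steps):
        # Step 3(a): forward pass to get the current class scores.
        score_history.append(class_scores(image)[c])
        # Step 3(b): gradient of S_c(I) - lam * ||I||_2^2 with respect to I.
        grad = score_grad(image, c) - 2 * lam * image
        # Step 3(c): gradient ascent, i.e. take a small step along the gradient.
        image = image + lr * grad
    return image, score_history

class_model, history = make_class_model(c=0)
print("S_0 at start:", history[0], "at end:", history[-1])
```

For this linear stand-in the optimum has a closed form, I* = W[c] / (2λ), which makes the loop easy to check; a real CNN has no such shortcut, which is why the iterative procedure is needed.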
This is gradient ascent, in which we update the image to maximize a score (in contrast to gradient descent, where we make updates to minimize a loss).
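Many optimizers are written to minimize. A common way to do gradient ascent with them, sketched below, is to minimize the negated objective: one descent step on −f is the same update as one ascent step on f. The concave function `f` here is a hypothetical toy with a known maximum, chosen only so the result is easy to check.

```python
import numpy as np

TARGET = np.array([1.0, -2.0])

def f(x):
    # Toy concave objective; its maximum is at x = TARGET.
    return -np.sum((x - TARGET) ** 2)

def grad_f(x):
    return -2 * (x - TARGET)

lr = 0.1
x0 = np.zeros(2)

# One update, written two ways.
ascent_step = x0 + lr * grad_f(x0)         # gradient ascent on f
descent_step = x0 - lr * (-grad_f(x0))     # gradient descent on -f
print(ascent_step, descent_step)           # the two updates are identical

# Repeating the ascent update climbs toward the maximum of f.
x = x0.copy()
for _ in range(100):
    x = x + lr * grad_f(x)
print(x, f(x))
```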
We can make a class model visualization for every class that the CNN is trained on. We train the CNN once, and then we repeat steps (2) and (3) for each possible choice of class c.
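Training happens once; only the optimization over the image is repeated for each class. The compact sketch below again uses a hypothetical fixed linear classifier (with unit-norm rows and no bias, an assumption made so the check is clean) in place of a trained CNN, builds one class model per class, and checks that each image scores highest on its own class.

```python
import numpy as np

rng = np.random.default_rng(1)
NUM_CLASSES, NUM_PIXELS = 4, 16

# Step 1 (done once): the "trained" model. A fixed linear classifier with
# unit-norm rows is a hypothetical stand-in for a real CNN.
W = rng.normal(size=(NUM_CLASSES, NUM_PIXELS))
W /= np.linalg.norm(W, axis=1, keepdims=True)

def class_model(c, lam=0.5, lr=0.1, num_steps=200):
    image = rng.normal(scale=0.01, size=NUM_PIXELS)   # step 2: random start
    for _ in range(num_steps):                        # step 3: gradient ascent
        image += lr * (W[c] - 2 * lam * image)        # gradient of S_c(I) - lam*||I||^2
    return image

# Steps 2 and 3 are repeated for every class; the model is never retrained.
class_models = {c: class_model(c) for c in range(NUM_CLASSES)}
for c, image in class_models.items():
    print("class", c, "-> highest-scoring class:", np.argmax(W @ image))
```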
Here is a PyTorch implementation of class models: pytorch-cnn-visualizations/src/generate_class_specific_samples.py . The following gif is from the repository, and shows the process of learning a class model for target class “Spider” starting from a randomly-initialized image:
Class models can help us understand what a CNN thinks a certain class “looks like.” They are a useful tool for gaining insight into what our trained CNN has learned about each class.
- Original research paper: Simonyan K, Vedaldi A, Zisserman A. Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034. 2013 Dec 20.
- Fei-Fei Li, Justin Johnson, Serena Yeung. Stanford CS 231n Lecture 12: Visualizing and Understanding (slides 21-24)
The featured image shows an ostrich and a class model image for the “ostrich” class of a VGG16 CNN, and is modified from images in the following GitHub repository: saketd403/Visualising-Image-Classification-Models-and-Saliency-Maps