Extract Image Features Using Pretrained Network

This example shows how to extract learned image features from a pretrained convolutional neural network and use those features to train an image classifier. Feature extraction is the easiest and fastest way to use the representational power of pretrained deep networks. For example, you can train a support vector machine (SVM) using fitcecoc (Statistics and Machine Learning Toolbox™) on the extracted features. Because feature extraction requires only a single pass through the data, it is a good starting point if you do not have a GPU to accelerate network training.

Load Data

Unzip and load the sample images as an image datastore. imageDatastore automatically labels the images based on folder names and stores the data as an ImageDatastore object. An image datastore lets you store large image data, including data that does not fit in memory. Split the data into 70% training and 30% test data.

unzip('MerchData.zip');
imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsTest] = splitEachLabel(imds,0.7,'randomized');

There are now 55 training images and 20 test images in this very small data set. Display some sample images.

numTrainImages = numel(imdsTrain.Labels);
idx = randperm(numTrainImages,16);
figure
for i = 1:16
    subplot(4,4,i)
    I = readimage(imdsTrain,idx(i));
    imshow(I)
end

[Figure: 4-by-4 grid of 16 sample training images.]
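The split sizes quoted above can be checked directly. The following is a minimal sketch that assumes imdsTrain and imdsTest already exist in the workspace from the split step; the variable names trainCounts, testCounts, numTrain, and numTest are introduced here for illustration only.

```matlab
% Assumes imdsTrain and imdsTest were created by splitEachLabel above.
% countEachLabel returns a table with one row per label and a Count column.
trainCounts = countEachLabel(imdsTrain)
testCounts = countEachLabel(imdsTest)

% Total number of images in each split.
numTrain = sum(trainCounts.Count)
numTest = sum(testCounts.Count)
```

Because splitEachLabel splits each label separately, every class should appear in both tables.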

Load Pretrained Network

Load a pretrained ResNet-18 network. If the Deep Learning Toolbox Model for ResNet-18 Network support package is not installed, then the software provides a download link. ResNet-18 is trained on more than a million images and can classify images into 1000 object categories, such as keyboard, mouse, pencil, and many animals. As a result, the model has learned rich feature representations for a wide range of images.

net = resnet18
net = 
  DAGNetwork with properties:
         Layers: [71x1 nnet.cnn.layer.Layer]
    Connections: [78x2 table]
     InputNames: {'data'}
    OutputNames: {'ClassificationLayer_predictions'}

Analyze the network architecture. The first layer, the image input layer, requires input images of size 224-by-224-by-3, where 3 is the number of color channels.

inputSize = net.Layers(1).InputSize;
analyzeNetwork(net)
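If you prefer to inspect the architecture programmatically rather than in the analyzeNetwork window, one option is to collect the layer names into a string array. This is only a sketch, assuming net is the ResNet-18 network loaded above; layerNames is a name introduced here for illustration.

```matlab
% Assumes net is the DAGNetwork returned by resnet18 above.
% Collect each layer's Name property into a string column vector.
layerNames = string(arrayfun(@(l) l.Name, net.Layers, 'UniformOutput', false));

% Show the last few layers of the network.
disp(layerNames(end-3:end))

% Input size expected by the first layer: [height width channels].
disp(net.Layers(1).InputSize)
```

Layer names matter because functions such as activations identify the layer to read from by its name.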

Extract Image Features

The network requires input images of size 224-by-224-by-3, but the images in the image datastores have different sizes. To automatically resize the training and test images before they are input to the network, create augmented image datastores, specify the desired image size, and use these datastores as input arguments to activations.

augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);

The network constructs a hierarchical representation of input images. Deeper layers contain higher-level features, constructed using the lower-level features of earlier layers. To get the feature representations of the training and test images, use activations on the global pooling layer, 'pool5', at the end of the network. The global pooling layer pools the input features over all spatial locations, giving 512 features in total.

layer = 'pool5';
featuresTrain = activations(net,augimdsTrain,layer,'OutputAs','rows');
featuresTest = activations(net,augimdsTest,layer,'OutputAs','rows');
whos featuresTrain
  Name               Size      Bytes   Class     Attributes
  featuresTrain      55x512    112640  single

Extract the class labels from the training and test data.

YTrain = imdsTrain.Labels;
YTest = imdsTest.Labels;

Fit Image Classifier

Use the features extracted from the training images as predictor variables and fit a multiclass support vector machine (SVM) using fitcecoc (Statistics and Machine Learning Toolbox).

classifier = fitcecoc(featuresTrain,YTrain);

Classify Test Images

Classify the test images using the trained SVM model and the features extracted from the test images.

YPred = predict(classifier,featuresTest);

Display four sample test images with their predicted labels.

idx = [1 5 10 15];
figure
for i = 1:numel(idx)
    subplot(2,2,i)
    I = readimage(imdsTest,idx(i));
    label = YPred(idx(i));
    imshow(I)
    title(char(label))
end

[Figure: four sample test images with predicted labels MathWorks Cap, MathWorks Cube, MathWorks Playing Cards, and MathWorks Screwdriver.]

Calculate the classification accuracy on the test set. Accuracy is the fraction of labels that the classifier predicts correctly.

accuracy = mean(YPred == YTest)
accuracy = 1
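To make the accuracy computation concrete, here is a small self-contained sketch on made-up labels. The label values and the names YTrue, YHat, isCorrect, acc, and C are hypothetical and are not taken from the MerchData set. The sketch also builds a confusion matrix with confusionmat (Statistics and Machine Learning Toolbox), which shows where any errors fall.

```matlab
% Hypothetical true and predicted labels, for illustration only.
YTrue = categorical(["cap" "cube" "cube" "cap" "torch"])';
YHat = categorical(["cap" "cube" "cap" "cap" "torch"])';

% Element-wise comparison gives a logical vector, one entry per observation.
isCorrect = (YHat == YTrue);

% The mean of a logical vector is the fraction of true entries.
acc = mean(isCorrect)   % 4 of 5 predictions match, so acc = 0.8

% Rows correspond to true classes, columns to predicted classes.
C = confusionmat(YTrue,YHat)
```

An accuracy of 1 on only 20 test images, as above, means no misclassifications occurred in this particular split. A confusion matrix becomes more informative once some predictions are wrong.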

Train Classifier on Shallower Features

You can also extract features from an earlier layer in the network and train a classifier on those features. Earlier layers typically extract fewer, shallower features, have a higher spatial resolution, and have a larger total number of activations. Extract the features from the 'res3b_relu' layer. This is the final layer that outputs 128 features, and its activations have a spatial size of 28-by-28.

layer = 'res3b_relu';
featuresTrain = activations(net,augimdsTrain,layer);
featuresTest = activations(net,augimdsTest,layer);
whos featuresTrain
  Name               Size            Bytes     Class     Attributes
  featuresTrain      28x28x128x55    22077440  single

The extracted features used in the first part of this example were pooled over all spatial locations by the global pooling layer. To achieve the same result when extracting features in earlier layers, manually average the activations over all spatial locations. To get the features in the form N-by-C, where N is the number of observations and C is the number of features, remove the singleton dimensions and transpose.

featuresTrain = squeeze(mean(featuresTrain,[1 2]))';
featuresTest = squeeze(mean(featuresTest,[1 2]))';
whos featuresTrain
  Name               Size      Bytes  Class     Attributes
  featuresTrain      55x128    28160  single
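The shape changes in the averaging step can be traced on a synthetic array. The sketch below uses random data in place of real activations; A, pooled, and F are names introduced here for illustration.

```matlab
% Synthetic stand-in for activations, laid out as H-by-W-by-C-by-N.
A = rand(28,28,128,55,'single');

% Average over the two spatial dimensions: result is 1-by-1-by-128-by-55.
pooled = mean(A,[1 2]);

% squeeze drops the singleton dimensions (128-by-55);
% the transpose puts observations in rows (55-by-128).
F = squeeze(pooled)';
size(F)

% Spot check: F(n,c) is the spatial mean of channel c for observation n.
n = 3; c = 7;
err = abs(F(n,c) - mean(A(:,:,c,n),'all'))
```

The transpose matters because fitcecoc expects one observation per row and one predictor per column.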

Train an SVM classifier on the shallower features. Calculate the test accuracy.

classifier = fitcecoc(featuresTrain,YTrain);
YPred = predict(classifier,featuresTest);
accuracy = mean(YPred == YTest)
accuracy = 0.9500

Both trained SVMs have high accuracies. If the accuracy is not high enough using feature extraction, then try transfer learning instead. For an example, see Train Deep Learning Network to Classify New Images. For a list and comparison of the pretrained networks, see Pretrained Deep Neural Networks.
