
Try Multiple Pretrained Networks for Transfer Learning

This example shows how to configure an experiment that replaces layers of different pretrained networks for transfer learning. Transfer learning is commonly used in deep learning applications. You can take a pretrained network and use it as a starting point to learn a new task. Fine-tuning a network with transfer learning is usually much faster and easier than training a network with randomly initialized weights from scratch. You can quickly transfer learned features to a new task using a smaller number of training images.

There are many pretrained networks available in Deep Learning Toolbox™. These pretrained networks have different characteristics that matter when choosing a network to apply to your problem. The most important characteristics are network accuracy, speed, and size. Choosing a network is generally a tradeoff between these characteristics. To compare the performance of different pretrained networks for your task, edit this experiment and specify which pretrained networks to use.

This experiment requires the Deep Learning Toolbox Model for GoogLeNet Network support package and the Deep Learning Toolbox Model for ResNet-18 Network support package. Before you run the experiment, install these support packages by calling the googlenet and resnet18 functions and clicking the download links. For more information on other pretrained networks that you can download from the Add-On Explorer, see Pretrained Deep Neural Networks.

Open Experiment

First, open the example. Experiment Manager loads a project with a preconfigured experiment that you can inspect and run. To open the experiment, in the Experiment Browser pane, double-click the name of the experiment (TransferLearningExperiment).

Built-in training experiments consist of a description, a table of hyperparameters, a setup function, and a collection of metric functions to evaluate the results of the experiment. For more information, see Configure Built-In Training Experiment.

The Description field contains a textual description of the experiment. For this example, the description is:

Perform transfer learning by replacing layers in a pretrained network.

The Hyperparameters section specifies the strategy (Exhaustive Sweep) and hyperparameter values to use for the experiment. When you run the experiment, Experiment Manager trains the network using every combination of hyperparameter values specified in the hyperparameter table. In this example, the hyperparameter NetworkName specifies the network to train and determines the value of the training option miniBatchSize.
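For reference, a sweep over six of the supported networks could be entered in the hyperparameter table like this. This is a minimal sketch: the exact set of values is an assumption, consistent with the six trials described below, and any subset of the networks handled by the setup function works.

NetworkName = ["squeezenet","googlenet","resnet18","mobilenetv2","resnet50","resnet101"]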

The Setup Function configures the training data, network architecture, and training options for the experiment. The input to the setup function is a structure with fields from the hyperparameter table. The setup function returns three outputs that you use to train a network for image classification problems. In this example, the setup function:

  • Loads a pretrained network corresponding to the hyperparameter NetworkName.

networkName = params.NetworkName;
switch networkName
    case "squeezenet"
        net = squeezenet;
        miniBatchSize = 128;
    case "googlenet"
        net = googlenet;
        miniBatchSize = 128;
    case "resnet18"
        net = resnet18;
        miniBatchSize = 128;
    case "mobilenetv2"
        net = mobilenetv2;
        miniBatchSize = 128;
    case "resnet50"
        net = resnet50;
        miniBatchSize = 128;
    case "resnet101"
        net = resnet101;
        miniBatchSize = 64;
    case "inceptionv3"
        net = inceptionv3;
        miniBatchSize = 64;
    case "inceptionresnetv2"
        net = inceptionresnetv2;
        miniBatchSize = 64;
    otherwise
        error("Undefined network selection.");
end
  • Downloads and extracts the Flowers data set, which is about 218 MB. For more information on this data set, see Image Data Sets.

url ="http://download.tensorflow.org/example_images/flower_photos.tgz"; downloadFolder = tempdir; filename = fullfile(downloadFolder,"flower_dataset.tgz");
imageFolder = fullfile(downloadFolder,"flower_photos");if~exist(imageFolder,"dir") disp("Downloading Flower Dataset (218 MB)...") websave(filename,url); untar(filename,downloadFolder)end
imds = imageDatastore(imageFolder,...IncludeSubfolders=true,...LabelSource="foldernames");
[imdsTrain, imdsValidation] = splitEachLabel (imd, 0.9); inputSize = net.Layers(1).InputSize; augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain); augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);
  • Replaces the learnable layers of the pretrained network to perform transfer learning. The helper function findLayersToReplace, which is listed in Appendix 2 at the end of this example, determines the layers in the network architecture to replace for transfer learning. For more information on the available pretrained networks, see Pretrained Deep Neural Networks.

lgraph = layerGraph(net);
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,"nnet.cnn.layer.FullyConnectedLayer")
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        Name="new_fc", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
elseif isa(learnableLayer,"nnet.cnn.layer.Convolution2DLayer")
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        Name="new_conv", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

newClassLayer = classificationLayer(Name="new_classoutput");
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
  • Defines a trainingOptions object for the experiment. The example trains the network for 10 epochs, using an initial learning rate of 0.0003 and validating the network every 5 epochs.

validationFrequencyEpochs = 5;
numObservations = augimdsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;

options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=3e-4, ...
    Shuffle="every-epoch", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Verbose=false);

To inspect the setup function, under Setup Function, click Edit. The setup function opens in MATLAB® Editor. In addition, the code for the setup function appears in Appendix 1 at the end of this example.

The Metrics section specifies optional functions that evaluate the results of the experiment. This example does not include any custom metric functions.
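If you want to add one, a custom metric function takes a single input, a structure with fields containing the trained network, the training information, and the hyperparameter values, and returns a scalar. The sketch below is illustrative only and is not part of this example; it returns the final validation accuracy of a trial as a fraction.

function metric = FinalValidationAccuracy(trialInfo)
% trialInfo contains the fields trainedNetwork, trainingInfo, and parameters.
valAccuracy = trialInfo.trainingInfo.ValidationAccuracy;  % vector with NaNs at iterations without validation
valAccuracy = valAccuracy(~isnan(valAccuracy));           % keep only the validation iterations
metric = valAccuracy(end)/100;                            % final validation accuracy as a fraction
end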

Run Experiment

When you run the experiment, Experiment Manager trains the network defined by the setup function six times. Each trial uses a different combination of hyperparameter values. By default, Experiment Manager runs one trial at a time. If you have Parallel Computing Toolbox™, you can run multiple trials at the same time or offload your experiment as a batch job in a cluster.

  • To run one trial of the experiment at a time, on the Experiment Manager toolstrip, under Mode, select Sequential and click Run.

  • To run multiple trials at the same time, under Mode, select Simultaneous and click Run. If there is no current parallel pool, Experiment Manager starts one using the default cluster profile. Experiment Manager then executes multiple simultaneous trials, depending on the number of parallel workers available. For best results, before you run your experiment, start a parallel pool with as many workers as GPUs, as shown in the sketch after this list. For more information, see Use Experiment Manager to Train Networks in Parallel and GPU Support by Release (Parallel Computing Toolbox).

  • To offload the experiment as a batch job, under Mode, select Batch Sequential or Batch Simultaneous, specify your Cluster and Pool Size, and click Run. For more information, see Offload Experiments as Batch Jobs to Cluster.
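For example, one way to start a pool with one worker per available GPU before running a Simultaneous experiment is shown below. This is a minimal sketch: it assumes Parallel Computing Toolbox, the default cluster profile, and a MATLAB release that supports the gpuDeviceCount syntax shown.

numGPUs = gpuDeviceCount("available");   % number of GPUs this MATLAB session can use
if numGPUs > 0
    parpool(numGPUs);                    % one worker per GPU, default cluster profile
end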

A table of results displays the accuracy and loss for each trial. While the experiment is running, click Training Plot to display the training plot and track the progress of each trial. Click Confusion Matrix to display the confusion matrix for the validation data in each completed trial.

When the experiment finishes, you can sort the results table by column, filter trials by using the Filters pane, or record observations by adding annotations. For more information, see Sort, Filter, and Annotate Experiment Results.

To test the performance of an individual trial, export the trained network or the training information for the trial. On the Experiment Manager toolstrip, select Export > Trained Network or Export > Training Information, respectively. For more information, see net and info. To save the contents of the results table as a table array in the MATLAB workspace, select Export > Results Table.
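For example, after exporting a trained network to the workspace, you can classify a new image with it. This is a minimal sketch: trainedNetwork is whatever variable name you choose in the export dialog, and "myFlower.jpg" is a hypothetical test image.

net = trainedNetwork;                              % variable exported from Experiment Manager
img = imread("myFlower.jpg");                      % hypothetical test image
img = imresize(img,net.Layers(1).InputSize(1:2));  % resize to match the network input size
label = classify(net,img)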

Close Experiment

In the Experiment Browser pane, right-click the name of the project and select Close Project. Experiment Manager closes all of the experiments and results contained in the project.

Appendix 1: Setup Function

This function configures the training data, network architecture, and training options for the experiment.

Input

  • params is a structure with fields from the Experiment Manager hyperparameter table.

Output

  • augimdsTrain is an augmented image datastore for the training data.

  • lgraph is a layer graph that defines the neural network architecture.

  • options is a trainingOptions object.

function [augimdsTrain,lgraph,options] = TransferLearningExperiment_setup1(params)

networkName = params.NetworkName;
switch networkName
    case "squeezenet"
        net = squeezenet;
        miniBatchSize = 128;
    case "googlenet"
        net = googlenet;
        miniBatchSize = 128;
    case "resnet18"
        net = resnet18;
        miniBatchSize = 128;
    case "mobilenetv2"
        net = mobilenetv2;
        miniBatchSize = 128;
    case "resnet50"
        net = resnet50;
        miniBatchSize = 128;
    case "resnet101"
        net = resnet101;
        miniBatchSize = 64;
    case "inceptionv3"
        net = inceptionv3;
        miniBatchSize = 64;
    case "inceptionresnetv2"
        net = inceptionresnetv2;
        miniBatchSize = 64;
    otherwise
        error("Undefined network selection.");
end

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    disp("Downloading Flower Dataset (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

imds = imageDatastore(imageFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
inputSize = net.Layers(1).InputSize;
augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);

lgraph = layerGraph(net);
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,"nnet.cnn.layer.FullyConnectedLayer")
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        Name="new_fc", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
elseif isa(learnableLayer,"nnet.cnn.layer.Convolution2DLayer")
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        Name="new_conv", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

newClassLayer = classificationLayer(Name="new_classoutput");
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);

validationFrequencyEpochs = 5;
numObservations = augimdsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;

options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=3e-4, ...
    Shuffle="every-epoch", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Verbose=false);

end

Appendix 2: Find Layers to Replace

This function finds the single classification layer and the preceding learnable (fully connected or convolutional) layer of the layer graph lgraph.

function [learnableLayer,classLayer] = findLayersToReplace(lgraph)

if ~isa(lgraph,"nnet.cnn.LayerGraph")
    error("Argument must be a LayerGraph object.")
end

src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

isClassificationLayer = arrayfun(@(l) ...
    (isa(l,"nnet.cnn.layer.ClassificationOutputLayer") | isa(l,"nnet.layer.ClassificationLayer")), ...
    lgraph.Layers);

if sum(isClassificationLayer) ~= 1
    error("Layer graph must have a single classification layer.")
end
classLayer = lgraph.Layers(isClassificationLayer);

currentLayerIdx = find(isClassificationLayer);
while true
    if numel(currentLayerIdx) ~= 1
        error("Layer graph must have a single learnable layer preceding the classification layer.")
    end

    currentLayerType = class(lgraph.Layers(currentLayerIdx));
    isLearnableLayer = ismember(currentLayerType, ...
        ["nnet.cnn.layer.FullyConnectedLayer","nnet.cnn.layer.Convolution2DLayer"]);

    if isLearnableLayer
        learnableLayer = lgraph.Layers(currentLayerIdx);
        return
    end

    currentDstIdx = find(layerNames(currentLayerIdx) == dst);
    currentLayerIdx = find(src(currentDstIdx) == layerNames);
end

end
