Texture Classification with Wavelet Image Scattering

This example shows how to classify textures using wavelet image scattering. In addition to Wavelet Toolbox™, this example also requires Parallel Computing Toolbox™ and Image Processing Toolbox™.

In a digital image, texture provides information about the spatial arrangement of color or pixel intensities. Particular spatial arrangements of color or pixel intensities correspond to different appearances and consistencies of the physical material being imaged. Texture classification and segmentation of images have a number of important application areas. A particularly important example is biomedical image analysis, where normal and pathologic states are often characterized by morphological and histological characteristics that manifest as differences in texture [4].

Wavelet Image Scattering

For classification problems, it is often useful to map the data into some alternative representation which discards irrelevant information while retaining the discriminative properties of each class. Wavelet image scattering constructs low-variance representations of images which are insensitive to translations and small deformations. Because translations and small deformations in the image do not affect class membership, scattering transform coefficients provide features from which you can build robust classification models.

Wavelet scattering works by cascading the image through a series of wavelet transforms, nonlinearities, and averaging [1][3][5]. The result of this deep feature extraction is that images in the same class are moved closer to each other in the scattering transform representation, while images belonging to different classes are moved farther apart.
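To make the cascade concrete, the following base-MATLAB sketch computes a single first-order scattering-style coefficient. It convolves the image with one complex Gabor filter, takes the modulus, and then averages. It is not the Wavelet Toolbox implementation. The filter, its parameters (`sigma`, `xi`), and the use of a global mean in place of a Gaussian lowpass filter are illustrative assumptions. Circular convolution commutes with circular shifts and the modulus acts pointwise, so the globally averaged coefficient does not change when the image is translated.

```matlab
% Illustrative first-order scattering-style coefficient (base MATLAB only).
% Assumptions: one hand-built complex Gabor filter stands in for a wavelet,
% and a global mean stands in for the Gaussian lowpass (averaging) filter.
rng(0);
img = rand(64);                          % synthetic 64-by-64 "texture"

% Complex Gabor filter on the same 64-by-64 grid (hypothetical parameters)
[xg, yg] = meshgrid(-32:31, -32:31);
sigma = 4;                               % envelope width in pixels
xi = pi/4;                               % center frequency along x (rad/pixel)
gabor = exp(-(xg.^2 + yg.^2)/(2*sigma^2)) .* exp(1i*xi*xg);
gabor = ifftshift(gabor);                % move filter center to index (1,1)

% Wavelet convolution (circular, via FFT) -> modulus -> global average
scatCoef = @(I) mean(mean(abs(ifft2(fft2(I) .* fft2(gabor)))));

s0 = mean(img(:));                       % zeroth-order coefficient (plain average)
s1 = scatCoef(img);                      % first-order coefficient

% Translate the image circularly by (5,9) pixels and recompute
shifted = circshift(img, [5 9]);
s1Shifted = scatCoef(shifted);
fprintf('s1 = %.6f, shifted s1 = %.6f\n', s1, s1Shifted);
```

If you replace the global mean with a lowpass filter of finite width, the invariance becomes approximate and local, extending over roughly the averaging scale. In `waveletScattering2`, the `InvarianceScale` property sets that scale.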

KTH-TIPS

This example uses a publicly available texture database, the KTH-TIPS (Textures under varying Illumination, Pose, and Scale) image database [6]. The KTH-TIPS dataset used in this example is the grayscale version. There are 810 images in total with 10 textures and 81 images per texture. The majority of images are 200-by-200 in size. This example assumes you have downloaded the KTH-TIPS grayscale dataset and untarred it so that the 10 texture classes are contained in separate subfolders of a common folder. Each subfolder is named for the class of textures it contains. Untarring the downloaded kth_tips_grey_200x200.tar file is sufficient to provide a top-level folder, KTH_TIPS, with the required subfolder structure.

Use imageDatastore to read the data. Set the location variable to the folder containing your copy of the KTH-TIPS database.

location = fullfile(tempdir,'kth_tips_grey_200x200','KTH_TIPS');
Imds = imageDatastore(location,'IncludeSubFolders',true, ...
    'FileExtensions','.png','LabelSource','foldernames');

Randomly select and visualize 20 images from the dataset.

numImages = 810;
perm = randperm(numImages,20);
for np = 1:20
    subplot(4,5,np)
    im = imread(Imds.Files{perm(np)});
    imagesc(im);
    colormap gray;
    axis off;
end

Texture Classification

This example uses MATLAB™'s parallel processing capability through the tall array interface. Start the parallel pool if one is not currently running.

% gcp('nocreate') returns empty when no pool is running, without starting one
if isempty(gcp('nocreate'))
    parpool;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

For reproducibility, seed the random number generator. Shuffle the files of the KTH-TIPS dataset and split the 810 images into two randomly selected sets: one for training and one held out for testing. Use approximately 80% of the images to build a predictive model from the scattering transform, and use the remainder to test the model.

rng(100)
Imds = imageDatastore(location,'IncludeSubFolders',true, ...
    'FileExtensions','.png','LabelSource','foldernames');
Imds = shuffle(Imds);
[trainImds,testImds] = splitEachLabel(Imds,0.8);

We now have two datasets. The training set consists of 650 images, with 65 images per texture. The testing set consists of 160 images, with 16 images per texture. To verify, count the labels in each dataset.

countEachLabel(trainImds)
ans=10×2 table
        Label         Count
    ______________    _____

    aluminium_foil     65
    brown_bread        65
    corduroy           65
    cotton             65
    cracker            65
    linen              65
    orange_peel        65
    sandpaper          65
    sponge             65
    styrofoam          65
countEachLabel(testImds)
ans=10×2 table
        Label         Count
    ______________    _____

    aluminium_foil     16
    brown_bread        16
    corduroy           16
    cotton             16
    cracker            16
    linen              16
    orange_peel        16
    sandpaper          16
    sponge             16
    styrofoam          16
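The 65/16 counts come from applying the 0.8 proportion to each class separately. 0.8 × 81 = 64.8, which becomes 65 training images and leaves 16 for testing. The sketch below reproduces that bookkeeping in base MATLAB with numeric labels. It is not how splitEachLabel is implemented. In particular, rounding to the nearest integer is an assumption that happens to match the counts reported above.

```matlab
% Per-class 80/20 split bookkeeping (illustrative; not splitEachLabel itself).
numClasses = 10;
perClass = 81;
labels = repelem((1:numClasses)', perClass);   % 810 numeric labels, 81 per class

p = 0.8;
trainIdx = [];
testIdx = [];
for c = 1:numClasses
    idx = find(labels == c);
    idx = idx(randperm(numel(idx)));           % shuffle within the class
    nTrain = round(p*numel(idx));              % assumed rounding: round(64.8) = 65
    trainIdx = [trainIdx; idx(1:nTrain)];      %#ok<AGROW>
    testIdx  = [testIdx;  idx(nTrain+1:end)];  %#ok<AGROW>
end

trainCounts = accumarray(labels(trainIdx), 1)';  % images per class, training set
testCounts  = accumarray(labels(testIdx), 1)';   % images per class, test set
disp(trainCounts)
disp(testCounts)
```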

Create tall arrays from the training and test datastores.

Ttrain = tall(trainImds);
Ttest = tall(testImds);

Create a scattering framework for an image input size of 200-by-200 with an InvarianceScale of 150. The invariance scale is the only hyperparameter set in this example. All other hyperparameters of the scattering transform use their default values.

sn = waveletScattering2('ImageSize',[200 200],'InvarianceScale',150);

To extract features for classification from each of the training and test sets, use the helperScatImages_mean function. The code for helperScatImages_mean is at the end of this example. helperScatImages_mean resizes the images to a common size of 200-by-200 and uses the scattering framework, sn, to obtain the feature matrix. In this case, each feature matrix is 391-by-7-by-7: there are 391 scattering paths, and each scattering coefficient image is 7-by-7. Finally, helperScatImages_mean takes the mean along the 2nd and 3rd dimensions to obtain a 391-element feature vector for each image. This is a significant reduction in data, from 40,000 elements down to 391.

trainfeatures = cellfun(@(x)helperScatImages_mean(sn,x),Ttrain,'Uni',0);
testfeatures = cellfun(@(x)helperScatImages_mean(sn,x),Ttest,'Uni',0);
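As a sanity check on the sizes quoted above, the sketch below counts scattering paths and reproduces the mean pooling. The path-count formula assumes a two-layer network with L = 6 rotations and J = 5 scales, and keeps second-order paths only for increasing scale (j1 < j2). These values are assumptions chosen because they reproduce the 391 paths reported here. They are not parameters read from `sn`. The coefficient array is random placeholder data.

```matlab
% Path count and mean pooling check (base MATLAB only; assumed L and J).
L = 6;       % rotations per layer (assumed)
J = 5;       % number of scales (assumed for InvarianceScale 150 on 200-by-200)
numPaths = 1 + L*J + L^2*J*(J-1)/2;   % zeroth + first + second order paths
fprintf('Scattering paths: %d\n', numPaths);

% Mean pooling of a paths-by-7-by-7 array, as helperScatImages_mean does.
smat = rand(numPaths, 7, 7);          % placeholder for featureMatrix output
pooled = mean(mean(smat, 2), 3);      % average over the two spatial dimensions
disp(size(pooled))

reduction = (200*200) / numel(pooled);   % roughly a 102-fold reduction
```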

Using the gather capability of tall arrays, gather all the training and test feature vectors and concatenate them into matrices.

Trainf = gather(trainfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 1 min 39 sec
Evaluation completed in 1 min 39 sec
trainfeatures = cat(2,Trainf{:});
Testf = gather(testfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 23 sec
Evaluation completed in 23 sec
testfeatures = cat(2,Testf{:});

The previous code produces two matrices, each with 391 rows and one column per image in the training or test set, respectively. Each column is a feature vector.
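You can illustrate the concatenation step without tall arrays. gather returns a cell array whose elements are 391-by-1 feature vectors, and `cat(2, C{:})` places them side by side. The sketch below uses five random placeholder vectors in place of real scattering features. `numDemo` is a hypothetical count chosen for illustration.

```matlab
% Concatenating per-image feature vectors into a paths-by-images matrix.
numDemo = 5;                                          % hypothetical number of images
C = arrayfun(@(k) rand(391,1), (1:numDemo)', ...
    'UniformOutput', false);                          % cell array, like gather output
F = cat(2, C{:});                                     % one column per image
disp(size(F))
```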

PCA Model and Prediction

This example constructs a simple classifier based on the principal components of the scattering feature vectors for each class. The classifier is implemented in the functions helperPCAModel and helperPCAClassifier. The function helperPCAModel determines the mean and principal components for each texture class based on the scattering features. The function helperPCAClassifier classifies the held-out test data by projecting each test feature vector onto the affine subspace of each class (the class mean plus its principal components) and assigning the class with the smallest projection error. The code for both functions is at the end of this example.

model = helperPCAModel(trainfeatures,30,trainImds.Labels);
predlabels = helperPCAClassifier(testfeatures,model);
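The idea behind the two helper functions fits in a few lines of base MATLAB on synthetic data. This is an illustrative re-implementation, not the helper code itself. It finds the principal directions with `svd` of the centered class features rather than `eig` of the covariance, as helperPCAModel does. Both approaches give the same subspace. It also compares squared residuals instead of their square roots, which selects the same class. The dimensions, class centers, and `M = 5` are arbitrary assumptions.

```matlab
% Class-wise PCA (affine subspace) classifier on synthetic data.
% Assumptions: features are columns; each class keeps its mean and the
% top M left singular vectors of its centered training features.
rng(0);
d = 20; nPerClass = 40; numClasses = 3; M = 5;

% Synthetic training data: one noisy cloud per class around its own center
centers = 10*randn(d, numClasses);
Xtrain = []; ytrain = [];
for c = 1:numClasses
    Xtrain = [Xtrain, centers(:,c) + randn(d, nPerClass)];  %#ok<AGROW>
    ytrain = [ytrain, c*ones(1, nPerClass)];                %#ok<AGROW>
end

% "Training": class means and principal directions
mu = cell(1, numClasses); U = cell(1, numClasses);
for c = 1:numClasses
    Xc = Xtrain(:, ytrain == c);
    mu{c} = mean(Xc, 2);
    [Uc, ~, ~] = svd(Xc - mu{c}, 'econ');    % columns ordered by singular value
    U{c} = Uc(:, 1:M);
end

% Classification: smallest squared residual after projection onto each subspace
Xtest = centers(:, [1 2 3 1]) + randn(d, 4);
err = zeros(numClasses, size(Xtest, 2));
for c = 1:numClasses
    S = Xtest - mu{c};                                   % center on class mean
    err(c,:) = sum(S.^2, 1) - sum((U{c}' * S).^2, 1);    % squared residual norm
end
[~, predicted] = min(err, [], 1);
disp(predicted)
```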

After constructing the model and classifying the test set, determine the accuracy of the test set classification.

accuracy = sum(testImds.Labels == predlabels)./numel(testImds.Labels)*100
accuracy = 99.3750

We have achieved 99.375% correct classification, or a 0.625% error rate, on the 160 images in the test set. A plot of the confusion matrix shows that our simple model misclassified one image.

figure
confusionchart(testImds.Labels,predlabels)
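The reported figures come from simple arithmetic on one misclassified image out of 160. The sketch below checks them and builds a confusion-count matrix with `accumarray`. The labels and the position of the single error are hypothetical. They stand in for the categorical labels that confusionchart consumes, and the variable names avoid overwriting `accuracy`.

```matlab
% Check of the reported accuracy and error rate: 1 error in 160 test images.
numTest = 160;
numWrong = 1;
accCheck = (numTest - numWrong)/numTest*100;   % percent correct
errCheck = numWrong/numTest*100;               % percent misclassified
fprintf('accuracy = %.3f%%, error rate = %.3f%%\n', accCheck, errCheck);

% Confusion counts from numeric labels (hypothetical: 10 classes x 16 images)
trueLabels = repelem((1:10)', 16);
predLabels = trueLabels;
predLabels(5) = 2;                             % one hypothetical class-1 -> class-2 error
cm = accumarray([trueLabels predLabels], 1, [10 10]);   % rows: true, columns: predicted
```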

Summary

In this example, we used wavelet image scattering to create low-variance representations of textures for classification. Using the scattering transform and a simple principal components classifier, we achieved 99.375% correct classification on a held-out test set. This result is comparable to state-of-the-art performance on the KTH-TIPS database [2].

References

[1] Bruna, J., and S. Mallat. "Invariant Scattering Convolution Networks." IEEE Transactions on Pattern Analysis and Machine Intelligence. Vol. 35, Number 8, 2013, pp. 1872–1886.

[2] Hayman, E., B. Caputo, M. Fritz, and J. O. Eklundh. "On the Significance of Real-World Conditions for Material Classification." In Computer Vision - ECCV 2004, edited by Tomás Pajdla and Jiří Matas, 3024:253–66. Berlin, Heidelberg: Springer Berlin Heidelberg, 2004. https://doi.org/10.1007/978-3-540-24673-2_21.

[3] Mallat, S. "Group Invariant Scattering." Communications on Pure and Applied Mathematics. Vol. 65, Number 10, 2012, pp. 1331–1398.

[4] Pujol, O., and P. Radeva. "Supervised Texture Classification for Intravascular Tissue Characterization." In Handbook of Biomedical Image Analysis, edited by Jasjit S. Suri, David L. Wilson, and Swamy Laxminarayan, 57–109. Boston, MA: Springer US, 2005. https://doi.org/10.1007/0-306-48606-7_2.

[5] Sifre, L., and S. Mallat. "Rotation, Scaling and Deformation Invariant Scattering for Texture Discrimination." 2013 IEEE Conference on Computer Vision and Pattern Recognition. 2013, pp. 1233–1240. https://doi.org/10.1109/CVPR.2013.163.

[6] KTH-TIPS image databases homepage. https://www.csc.kth.se/cvap/databases/kth-tips/

Appendix — Supporting Functions

helperScatImages_mean

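function features = helperScatImages_mean(sf,x)
% Resize the image to 200-by-200, compute the scattering feature matrix
% (paths-by-rows-by-columns), and average over the two spatial dimensions
% to obtain one feature vector (paths-by-1) per image.
x = imresize(x,[200 200]);
smat = featureMatrix(sf,x);
features = mean(mean(smat,2),3);
end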

helperPCAModel

function model = helperPCAModel(features,M,Labels)
% This function is only to support wavelet image scattering examples in
% Wavelet Toolbox. It may change or be removed in a future release.
% model = helperPCAModel(features,M,Labels)

% Copyright 2018 MathWorks

% Initialize structure array to hold the affine model
model = struct('Dim',[],'mu',[],'U',[],'Labels',categorical([]),'S',[]);
model.Dim = M;
% Obtain the number of classes
LabelCategories = categories(Labels);
Nclasses = numel(categories(Labels));
for kk = 1:Nclasses
    Class = LabelCategories{kk};
    % Find indices corresponding to each class
    idxClass = Labels == Class;
    % Extract feature vectors for each class
    tmpFeatures = features(:,idxClass);
    % Determine the mean for each class
    model.mu{kk} = mean(tmpFeatures,2);
    [model.U{kk},model.S{kk}] = scatPCA(tmpFeatures);
    % Keep at most M principal components
    if size(model.U{kk},2) > M
        model.U{kk} = model.U{kk}(:,1:M);
        model.S{kk} = model.S{kk}(1:M);
    end
    model.Labels(kk) = Class;
end
end

function [u,s,v] = scatPCA(x,M)
% Calculate the principal components of x along the second dimension.
if nargin > 1 && M > 0
    % If M is non-zero, calculate the first M principal components.
    [u,s,v] = svds(x-mean(x,2),M);
    s = abs(diag(s)/sqrt(size(x,2)-1)).^2;
else
    % Otherwise, calculate all the principal components.
    % Rows of x are variables (scattering paths); columns are
    % observations (feature vectors of one class).
    [u,d] = eig(cov(x'));
    [s,ind] = sort(diag(d),'descend');
    u = u(:,ind);
end
end

helperPCAClassifier

function labels = helperPCAClassifier(features,model)
% This function is only to support wavelet image scattering examples in
% Wavelet Toolbox. It may change or be removed in a future release.
% model is a structure array with fields Dim, mu, U, S, and Labels.
% features is the matrix of test data, which is Ns-by-L, where Ns is the
% number of scattering paths and L is the number of test examples. Each
% column of features is a test example.

% Copyright 2018 MathWorks

labelIdx = determineClass(features,model);
labels = model.Labels(labelIdx);
% Return as column vector to agree with imageDatastore Labels
labels = labels(:);
end

%--------------------------------------------------------------------------
function labelIdx = determineClass(features,model)
% Determine number of classes
Nclasses = numel(model.Labels);
% Initialize error matrix
errMatrix = Inf(Nclasses,size(features,2));
for nc = 1:Nclasses
    % Class centroid and principal components
    mu = model.mu{nc};
    u = model.U{nc};
    % 1-by-L
    errMatrix(nc,:) = projectionError(features,mu,u);
end
% Determine minimum along class dimension
[~,labelIdx] = min(errMatrix,[],1);
end

%--------------------------------------------------------------------------
function totalerr = projectionError(features,mu,u)
% Distance from each column of features to the affine subspace mu + span(u)
Npc = size(u,2);
L = size(features,2);
% Subtract class mean: Ns-by-L minus Ns-by-1
s = features-mu;
% 1-by-L squared norms
normSqX = sum(abs(s).^2,1);
err = Inf(Npc+1,L);
err(1,:) = normSqX;
err(2:end,:) = -abs(u'*s).^2;
% 1-by-L
totalerr = sqrt(sum(err,1));
end
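projectionError never forms the projected vector explicitly. Take a test vector x, a class mean μ_c, and a matrix U_c with orthonormal columns. The eigenvectors that eig returns for the symmetric covariance matrix are orthonormal, and keeping the first M columns preserves that. Under these conditions, the squared distance to the class's affine subspace satisfies the identity below. The first row of err holds the first term, and the remaining rows sum to the negative of the second term. The final square root does not change which class attains the minimum.

```latex
\begin{aligned}
s &= x - \mu_c, \\
\lVert s - U_c U_c^{\top} s \rVert_2^2
  &= s^{\top}s - 2\, s^{\top}U_c U_c^{\top}s + s^{\top}U_c \underbrace{U_c^{\top}U_c}_{I} U_c^{\top}s \\
  &= \lVert s \rVert_2^2 - \lVert U_c^{\top} s \rVert_2^2, \\
\hat{c}(x) &= \arg\min_{c} \, \lVert s - U_c U_c^{\top} s \rVert_2 .
\end{aligned}
```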
