Step 14:
Add an input statement and a conditional statement. The input
statement asks the user whether s/he would like to train the network or
load the previously trained network. If the user decides to
skip retraining, the program should load the
previously trained and saved network (digitnet). Note that in
step 12 you saved the trained network in a file called digitnet (MATLAB stores it as digitnet.mat).
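The decision described above can be tried on its own before it goes into the full program. This is a minimal sketch, not part of the original program: `isValidAnswer` and `mustTrain` are hypothetical helper names, and the `fileExists` flag is an added assumption so that the program trains anyway when no saved digitnet.mat exists yet. Because `strcmpi` compares whole strings and ignores case, replies such as 'yes' or an empty reply are rejected instead of slipping through.

```matlab
% Hypothetical helpers (not in the original program).
% isValidAnswer: true only if the whole reply is 'y' or 'n', ignoring case.
isValidAnswer = @(s) any(strcmpi(s, {'y', 'n'}));
% mustTrain: train if the user asked to, or if no saved network exists yet
% (the fileExists check is an added assumption, not part of step 14).
mustTrain = @(answer, fileExists) strcmpi(answer, 'y') || ~fileExists;

disp(isValidAnswer('N'))     % 1: case is ignored
disp(isValidAnswer('yes'))   % 0: the whole string must match
disp(mustTrain('n', false))  % 1: nothing saved to load, so train
disp(mustTrain('n', true))   % 0: load digitnet instead
```

In the program, the flag could come from `exist('digitnet.mat','file') == 2`, which is true when the saved file is in the current folder or on the MATLAB path.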
clear all;
close all;
clc;
trainagain=input('Would you like to train the network [y or n]? ','s');
while ~any(strcmpi(trainagain,{'y','n'})) % error-trap the input: repeat until the answer is exactly y or n (either case)
trainagain=input('Would you like to train the network again [y or n]? ','s');
end;
if strcmpi(trainagain,'y')==1 % if the user wants to retrain the network, execute the following lines down to the "else" statement below.
[XTrain,YTrain] = digitTrain4DArrayData;
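size(XTrain) % images: 28x28x1x5000 (height x width x channels x number of images)
size(YTrain) % correct answer labels: one categorical label (0-9) per image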
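XTrain=1-XTrain; % Reverse the black and white colors: the training digits are white on black, so inverting gives dark digits on a light background, like ink on paper seen by the webcam. Save and run the program to see the difference.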
perm = randperm(size(XTrain,4),20); % Pick 20 random image indices (size(XTrain,4) is the number of images) to display below
for i = 1:20
subplot(4,5,i);
imshow(XTrain(:,:,:,perm(i)));
end
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
averagePooling2dLayer(7)
fullyConnectedLayer(10) % 10 output layer nodes
softmaxLayer
classificationLayer]; %close the bracket
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.1, ...
'MaxEpochs',20, ...
'Verbose',false, ...
'Plots','training-progress', ...
'Shuffle','every-epoch' );
net1 = trainNetwork(XTrain,YTrain,layers,options);
save digitnet net1;
else
load digitnet; % If you don't want to retrain the network, then load and use the previously trained network (net1).
end; %if strcmpi
cam = webcam; % Connect to the camera
figure; % open new figure window
while true % this loop runs forever unless you break out of it (press Ctrl+C in the Command Window to stop)
im = snapshot(cam); % Take a picture
image(im); % Show the picture
im=rgb2gray(im); %make the image grayscale (that's what the network is expecting)
im=round(double(im)/255); % convert 0-255 integers to 0-1 doubles; round then makes each pixel pure black (0) or white (1), which increases contrast
im = imresize(im,[28 28]); % Resize the picture for the network you trained (it's expecting a 28x28 image)
label = classify(net1,im); % Classify the picture. Type help classify in the command window to get more information about this command.
title(char(label),'fontsize',18); % Show the class label
drawnow % force MATLAB to display the image and label immediately
end
% Copy and paste this whole program in a new editor window and test it out.
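The layer sizes explain why each webcam frame is resized to 28x28 and why the last pooling layer uses a 7x7 window. The sketch below only traces the arithmetic and does not build the network. It assumes the usual output-size rule for pooling without padding, floor((h - pool)/stride) + 1; that the stride-1 convolutions with 'Padding','same' keep the spatial size; that batchNormalizationLayer and reluLayer leave sizes unchanged; and that averagePooling2dLayer uses its default stride of 1. `poolOut` is a hypothetical helper name.

```matlab
% Sketch: trace the spatial size (height = width) through the layers array.
% poolOut is a hypothetical helper: pooling output size with no padding.
poolOut = @(h, pool, stride) floor((h - pool)/stride) + 1;

h = 28;                  % imageInputLayer([28 28 1])
% convolution2dLayer(3,8,'Padding','same'): still 28x28, 8 channels
h = poolOut(h, 2, 2);    % maxPooling2dLayer(2,'Stride',2): 28 -> 14
% convolution2dLayer(3,16,'Padding','same'): still 14x14, 16 channels
h = poolOut(h, 2, 2);    % maxPooling2dLayer(2,'Stride',2): 14 -> 7
% convolution2dLayer(3,32,'Padding','same'): still 7x7, 32 channels
h = poolOut(h, 7, 1);    % averagePooling2dLayer(7), stride 1: 7 -> 1
features = h * h * 32;   % values passed to fullyConnectedLayer(10)
fprintf('final map %dx%dx32 -> %d features -> 10 class scores\n', h, h, features);
```

With a 28x28 input and two halving pooling layers, the map is 7x7 when it reaches averagePooling2dLayer(7), so that window reduces it to 1x1x32. Changing the input size or removing a pooling layer would change these numbers.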