本文的代码来自githup的Deep Learning的toolbox,是用Matlab实现的。感谢该toolbox的作者付出和分享。
我在应该该代码进行训练时,出现一些报错,如expand函数应用不对,sigm函数和flipall函数未定义等问题,对这些问题进行了修正,完成网络训练和验证。
本文mnist_uint8.mat的获取可以参照我的另一篇博客:MNIST数据库处理--matlab生成mnist_uint8.mat http://blog.csdn.net/fuwenyan/article/details/53954615
本文作者:非文艺小燕儿,VivienFu,
欢迎大家转载适用,请注明出处。http://blog.csdn.net/fuwenyan?viewmode=contents http://blog.csdn.net/fuwenyan?viewmode=contents
学习该代码需要一定的CNN理论基础。
目的:实现手写数字识别
数据集:MNIST数据集,60000张训练图像,10000张测试图像,每张图像size为28*28
网络层级结构概述:5层神经网络
Input layer: 输入数据为原始训练图像
Conv1:6个5*5的卷积核,步长Stride为1
Pooling1:卷积核size为2*2,步长Stride为2
Conv2:12个5*5的卷积核,步长Stride为1
Pooling2:卷积核size为2*2,步长Stride为2
Output layer:输出为10维向量
网络层级结构示意图如下:
代码流程概述:
(1)获取训练数据和测试数据;
(2)定义网络层级结构;
(3)初始设置网络参数(权重W,偏向b)cnnsetup(cnn, train_x, train_y)
(4)训练超参数opts定义(学习率,batchsize,epoch)
(5)网络训练之前向运算cnnff(net, batch_x)
(6)网络训练之反向传播cnnbp(net, batch_y)
(7)网络训练之参数更新cnnapplygrads(net, opts)
(8)重复(5)(6)(7),直至满足epoch
(9)网络测试cnntest(cnn, test_x, test_y)
详细代码及我的注释
cnnexamples.m
clear all; close all; clc;
load mnist_uint8;
train_x = double(reshape(train_x',28,28,60000))/255; %数据归一化至[0 1]之间
test_x = double(reshape(test_x',28,28,10000))/255; %数据归一化至[0 1]之间
train_y = double(train_y');
test_y = double(test_y');
%% ex1
%will run 1 epoch in about 200 second and get around 11% error.
%With 100 epochs you'll get around 1.2% error
cnn.layers = {
struct('type', 'i') %input layer
struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5) %convolution layer,6个5*5的卷积核,可以得到6个outputmaps
struct('type', 's', 'scale', 2) %sub sampling layer ,2*2的下采样卷积核
struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) %convolution layer ,12个5*5的卷积核,可以得到12个outputmaps
struct('type', 's', 'scale', 2) %subsampling layer ,2*2的下采样卷积核
}; %定义了一个5层神经网络,还有一个输出层,并未在这里定义。
cnn = cnnsetup(cnn, train_x, train_y); %通过该函数,对网络初始权重矩阵和偏向进行初始化
opts.alpha = 1; % 学习率
opts.batchsize = 50; %每batchsize张图像一起训练一轮,调整一次权值。
% 训练次数,用同样的样本集。我训练的时候:
% 1的时候 11.41% error
% 5的时候 4.2% error
% 10的时候 2.73% error
opts.numepochs = 10; %每个epoch内,对所有训练数据进行训练,更新(训练图像个数/batchsize)次网络参数
cnn = cnntrain(cnn,