# -*- coding:utf-8 -*-
'''
Created on 2018-1-19
'''
import numpy as np
import tflearn
import tflearn.datasets.mnist as mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
def do_DNN(X, Y, testX, testY):
# Building deep neural network
# 定义输入层
input_layer = tflearn.input_data(shape=[None, 784])
# 定义一个64神经元的全连接隐藏层,
dense1 = tflearn.fully_connected(input_layer, 64, activation='tanh',regularizer='L2', weight_decay=0.001)
# 对全连接使用一个dropout,随机停用部分神经元,防止过拟合
dropout1 = tflearn.dropout(dense1, 0.8)
dense2 = tflearn.fully_connected(dropout1, 64, activation='tanh',regularizer='L2', weight_decay=0.001)
dropout2 = tflearn.dropout(dense2, 0.8)
# 定义输出层,使用softmax分类
softmax = tflearn.fully_connected(dropout2, 10, activation='softmax')
# Regression using SGD with learning rate decay and Top-3 accuracy
# 反向传播使用SGD随机梯度下降法,学习率递减
sgd = tflearn.SGD(learning_rate=0.1, lr_decay=0.96, decay_step=1000)
# 定义,真实结果在预测结果前3中就算正确
top_k = tflearn.metrics.Top_k(3)
# 定义回归策略
net = tflearn.regression(softmax, optimizer=sgd, metric=top_k,loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=20, validation_set=(testX, testY),
show_metric=True, run_id="dense_model")
# 准确率98%
do_DNN(X, Y, testX, testY)
MNIST数据集的下载
MNIST是一些手写数字的图片,通过http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz下载数据集。