TensorFlow实现卷积神经网络CNN

一、卷积神经网络CNN简介

卷积神经网络(ConvolutionalNeuralNetwork,CNN)最初是为解决图像识别等问题设计的,CNN现在的应用已经不限于图像和视频,也可用于时间序列信号,比如音频信号和文本数据等。CNN作为一个深度学习架构被提出的最初诉求是降低对图像数据预处理的要求,避免复杂的特征工程。在卷积神经网络中,第一个卷积层会直接接受图像像素级的输入,每一层卷积(滤波器)都会提取数据中最有效的特征,这种方法可以提取到图像中最基础的特征,而后再进行组合和抽象形成更高阶的特征,因此CNN在理论上具有对图像缩放、平移和旋转的不变性。

卷积神经网络CNN的要点就是局部连接(LocalConnection)、权值共享(WeightsSharing)和池化层(Pooling)中的降采样(Down-Sampling)。其中,局部连接和权值共享降低了参数量,使训练复杂度大大下降并减轻了过拟合。同时权值共享还赋予了卷积网络对平移的容忍性,池化层降采样则进一步降低了输出参数量并赋予模型对轻度形变的容忍性,提高了模型的泛化能力。可以把卷积层卷积操作理解为用少量参数在图像的多个位置上提取相似特征的过程。

更多请参见:深度学习之卷积神经网络CNN

二、TensorFlow代码实现

#!/usr/bin/env python2 
# -*- coding: utf-8 -*- 
""" 
Created on Thu Mar 9 22:01:46 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) 
sess = tf.InteractiveSession() 
 
def weight_variable(shape): 
 initial = tf.truncated_normal(shape,stddev=0.1) #标准差为0.1的正态分布 
 return tf.Variable(initial) 
 
def bias_variable(shape): 
 initial = tf.constant(0.1,shape=shape) #偏差初始化为0.1 
 return tf.Variable(initial) 
 
def conv2d(x,W): 
 return tf.nn.conv2d(x,W,strides=[1,1,1],padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x,ksize=[1,2,padding='SAME') 
 
x = tf.placeholder(tf.float32,[None,784]) 
y_ = tf.placeholder(tf.float32,10]) 
# -1代表先不考虑输入的图片例子多少这个维度,1是channel的数量 
x_image = tf.reshape(x,[-1,28,1]) 
keep_prob = tf.placeholder(tf.float32) 
 
# 构建卷积层1 
W_conv1 = weight_variable([5,5,32]) # 卷积核5*5,1个channel,32个卷积核,形成32个featuremap 
b_conv1 = bias_variable([32]) # 32个featuremap的偏置 
h_conv1 = tf.nn.relu(conv2d(x_image,W_conv1) + b_conv1) # 用relu非线性处理 
h_pool1 = max_pool_2x2(h_conv1) # pooling池化 
 
# 构建卷积层2 
W_conv2 = weight_variable([5,32,64]) # 注意这里channel值是32 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
# 构建全连接层1 
W_fc1 = weight_variable([7*7*64,1024]) 
b_fc1 = bias_variable([1024]) 
h_pool3 = tf.reshape(h_pool2,7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool3,W_fc1) + b_fc1) 
h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob) 
 
# 构建全连接层2 
W_fc2 = weight_variable([1024,10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2) + b_fc2) 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv),reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.arg_max(y_conv,1),tf.arg_max(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 
 
tf.global_variables_initializer().run() 
 
for i in range(20001): 
 batch = mnist.train.next_batch(50) 
 if i % 100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob: 1.0}) 
  print("step %d,training accuracy %g" %(i,train_accuracy)) 
 train_step.run(feed_dict={x: batch[0],y_: batch[1],keep_prob:0.5}) 
print("test accuracy %g" %accuracy.eval(feed_dict={x: mnist.test.images,y_: mnist.test.labels,keep_prob: 1.0})) 

三、代码解读

该代码是用TensorFlow实现一个简单的卷积神经网络,在数据集MNIST上,预期可以实现99.2%左右的准确率。结构上使用两个卷积层和一个全连接层。

首先载入MNIST数据集,采用独热编码,并创建tf.InteractiveSession。然后为后续即将多次使用的部分代码创建函数,包括权重初始化weight_variable、偏置初始化bias_variable、卷积层conv2d、最大池化max_pool_2x2。其中权重初始化的时候要进行含有噪声的非对称初始化,打破完全对称。又由于我们要使用ReLU单元,也需要给偏置bias增加一些小的正值(0.1)用来避免死亡节点(dead neurons)。

构建卷积神经网络之前,先要定义输入的placeholder,特征x和真实标签y_,将1*784格式的特征x转换reshape为28*28的图片格式,又由于只有一个通道且不确定输入样本的数量,故最终尺寸为[-1,1]。

接下来定义第一个卷积层,首先初始化weights和bias,然后使用conv2d进行卷积操作并加上偏置,随后使用ReLU激活函数进行非线性处理,最后使用最大池化函数对卷积的输出结果进行池化操作。

相同的步骤定义第二个卷积层,不同的地方是卷积核的数量为64,也就是说这一层的卷积会提取64种特征。经过两层不变尺寸的卷积和两次尺寸减半的池化,第二个卷积层后的输出尺寸为7*7*64。将其reshape为长度为7*7*64的1-D向量。经过ReLU后,为了减轻过拟合,使用一个Dropout层,在训练时随机丢弃部分节点的数据减轻过拟合,在预测的时候保留全部数据来追求最好的测试性能。

最后加一个Softmax层,得到最后的预测概率。随后的定义损失函数、优化器、评测准确率不再详细赘述。

训练过程首先进行初始化全部参数,训练时keep_prob比率设置为0.5,评测时设置为1。训练完成后,在最终的测试集上进行全面的测试,得到整体的分类准确率。

经过实验,这个CNN的模型可以得到99.2%的准确率,相比于MLP又有了较大幅度的提高。

四、其他解读补充

1. tf.nn.conv2d(x,padding='SAME')

tf.nn.conv2d是TensorFlow的2维卷积函数,x和W都是4-D的tensors。x是输入input shape=[batch,in_height,in_width,in_channels],W是卷积的参数filter / kernel shape=[filter_height,filter_width,in_channels,out_channels]。strides参数是长度为4的1-D参数,代表了卷积核(滑动窗口)移动的步长,其中对于图片strides[0]和strides[3]必须是1,都是1表示不遗漏地划过图片的每一个点。padding参数中SAME代表给边界加上Padding让卷积的输出和输入保持相同的尺寸。

2. tf.nn.max_pool(x,padding='SAME')

tf.nn.max_pool是TensorFlow中的最大池化函数,x是4-D的输入tensor shape=[batch,height,width,channels],ksize参数表示池化窗口的大小,取一个4维向量,一般是[1,1],因为我们不想在batch和channels上做池化,所以这两个维度设为了1,strides与tf.nn.conv2d相同,strides=[1,1]可以缩小图片尺寸。padding参数也参见tf.nn.conv2d。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


本文从多个角度分析了vi编辑器保存退出命令。我们介绍了保存和退出vi编辑器的命令,以及如何撤销更改、移动光标、查找和替换文本等实用命令。希望这些技巧能帮助你更好地使用vi编辑器。
Python中的回车和换行是计算机中文本处理中的两个重要概念,它们在代码编写中扮演着非常重要的角色。本文从多个角度分析了Python中的回车和换行,包括回车和换行的概念、使用方法、使用场景和注意事项。通过本文的介绍,读者可以更好地理解和掌握Python中的回车和换行,从而编写出更加高效和规范的Python代码。
SQL Server启动不了错误1067是一种比较常见的故障,主要原因是数据库服务启动失败、权限不足和数据库文件损坏等。要解决这个问题,我们需要检查服务日志、重启服务器、检查文件权限和恢复数据库文件等。在日常的数据库运维工作中,我们应该时刻关注数据库的运行状况,及时发现并解决问题,以确保数据库的正常运行。
信息模块是一种可重复使用的、可编程的、可扩展的、可维护的、可测试的、可重构的软件组件。信息模块的端接需要从接口设计、数据格式、消息传递、函数调用等方面进行考虑。信息模块的端接需要满足高内聚、低耦合的原则,以保证系统的可扩展性和可维护性。
本文从电脑配置、PyCharm版本、Java版本、配置文件以及程序冲突等多个角度分析了Win10启动不了PyCharm的可能原因,并提供了解决方法。
本文主要从多个角度分析了安装SQL Server 2012时可能出现的错误,并提供了解决方法。
Pycharm是一款非常优秀的Python集成开发环境,它可以让Python开发者更加高效地进行代码编写、调试和测试。在Pycharm中设置解释器非常简单,我们可以通过创建新项目、修改项目解释器、设置全局解释器等多种方式进行设置。
Python中有多种方法可以将字符串转换为整数,包括使用int()函数、try-except语句、正则表达式、map()函数、ord()函数和reduce()函数。在实际应用中,应根据具体情况选择最合适的方法。
本文介绍了导入CSV文件的多种方法,包括使用Excel、Python和R等工具。同时,还介绍了导入CSV文件时需要注意的一些细节和问题。CSV文件是数据处理和分析中不可或缺的一部分,希望本文能够对读者有所帮助。
mongodb是一种新型的数据库,它采用了面向文档的数据模型,具有灵活性、高性能和高可用性等优势。但是,mongodb也存在数据结构混乱、安全性和学习成本高等问题。
当Python运行不了时,我们应该从代码、Python环境、操作系统和硬件设备等多个角度来排查问题,并采取相应的解决措施。
Python列表是一种常见的数据类型,排序是列表操作中的一个重要部分。本文介绍了Python列表降序排序的方法,包括使用sort()函数、sorted()函数以及自定义函数进行排序。使用sort()函数可以简单方便地实现降序排序,但会改变原始列表的顺序;使用sorted()函数可以保留原始列表的顺序,但需要创建一个新的列表;使用自定义函数可以灵活地控制排序的方式,但需要编写额外的代码。
本文介绍了如何使用Python输入一段英文并统计其中的单词个数,从去除标点符号、忽略单词大小写、排除常用词汇等多个角度进行了分析。此外,还介绍了使用NLTK库进行单词统计的方法。
虚拟环境可以帮助我们在同一台机器上运行不同版本的Python、安装不同的Python包,并且不会相互影响。创建虚拟环境的命令是python3 -m venv myenv,进入虚拟环境的命令是source myenv/bin/activate,退出虚拟环境的命令是deactivate。在虚拟环境中可以使用pip安装包,也可以使用Python运行程序。
本文从XHR对象、fetch API和jQuery三个方面分析了JS获取响应状态的方法及其应用。以上三种方法都可以轻松地发送HTTP请求,并处理响应数据。
桌面的命令包括常见的操作命令、系统命令、批处理命令以及第三方应用程序提供的命令。我们可以通过鼠标右键点击桌面、创建快捷方式、创建批处理文件等方式来运用这些命令,从而更好地管理计算机,提高工作效率。
本文分析了应用程序闪退的多个原因,包括应用程序本身存在问题、手机或平板电脑系统问题、硬件问题、网络问题和其他原因。同时,本文提供了解决闪退问题的多种方式,包括更新或卸载重新下载应用程序、升级系统或进行修复、清理手机缓存、清理不必要的文件或者是更换电池等方式来解决、确保网络信号的稳定性、注意用户隐私和安全问题。
本文介绍了使用Python下载图片的多种方法,包括使用Python标准库urllib.request、第三方库requests、多线程和异步IO。这些方法在不同情况下都有它们的优缺点。使用这些方法,我们可以轻松地将网络上的图片下载到本地,方便我们在离线状态下查看或处理这些图片。
MySQL数据文件是指存储MySQL数据库中数据的文件,存储位置的选择对数据库的性能、可靠性和安全性都有着重要的影响。本文从存储位置的选择、存储设备的选择、存储空间的管理和存储位置的安全性等多个角度对MySQL数据文件的存储位置进行分析,最后得出需要根据实际情况综合考虑多个因素,选择合适的存储位置和存储设备,并进行有效的存储空间管理和安全措施的结论。
AS400是一种主机操作系统,每个库都包含多个表。查询库表总数是一项基本任务。可以使用命令行、系统管理界面以及数据库管理工具来查询库表总数。查询库表总数可以帮助用户更好地管理和优化数据,包括规划数据存储、优化查询性能以及管理空间资源。