介绍一些神经网络中常用的优化方法。包括动态学习率、正则化防止过拟合、滑动平均模型。
优化方法
学习率的设置
TensorFlow提供了一种学习率设置方法——指数衰减法。全部方法见Decaying_the_learning_rate
tf.train.exponential_decay
函数先使用较大的学习率来快速得到较优解,然后随着训练步数的增多,学习率逐步降低,最后进行微调。
该函数的官网介绍:exponential_decay、GitHub介绍。
函数定义如下:
1 | exponential_decay( |
该函数通过步数动态的计算学习率,计算公式如下:
1 | decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps) |
其中:
decayed_learning_rate
: 每一轮优化时采用的学习率learning_rate
: 实现设定的初始学习率decay_rate
: 衰减系数decay_steps
: 衰减速度
staircase
如果为True
,则global_step / decay_steps
会转化为整数,这使得学习率变化函数从平滑曲线变为了阶梯函数(staircase function)。
示例代码如下:
1 | BATCH_SIZE = 100 |
完整代码见后文。
通过正则化来避免过拟合问题
有关过拟合,即通过训练得到的模型不具有通用性,在测试集上表现不佳。
为了避免过拟合,常用方法有dropout、正则化(regularization)。
正则化的思想就是在损失函数中加入刻画模型复杂度的指标。如何刻画复杂度,一般做法是对权重W进行正则化。
正则化包含L1正则化和L2正则化。多用L2正则化。原文对比见github
L2比L1好用的原因:
- L1使得参数变得更稀疏,即一些参数会变为0。L2会使得参数保持一个很小的数字,比如0.001。
- L1正则化公式不可导,L2正则化公式可导。
L1和L2的测试代码如下:
1 | import tensorflow as tf |
输出如下:
1 | 5.0 |
计算方法如下:
1 | L1 = (|1| + |-2| + |-3| + |4|) * 0.5 = 5 |
0.5为正则化项的权重lambda。TensorFlow将L2的正则化损失值除以2使求导得到的结果更加简洁。
通过集合Collection解决层数过多时代码过长问题
思路:将所有的权重向量加入到一个集合中,最后累加这个集合中的变量。
示例,构建5层神经网络代码如下:
1 | import tensorflow as tf |
滑动平均模型
在采用随即梯度下降算法训练时,使用滑动平均模型会在一定程度上提供最终模型在测试数据上的性能。
类tf.train.ExponentialMovingAverage
(DOC、GitHub)
初始化函数
1 | __init__( |
其中decay
为衰减率。用来控制模型更新的速度。区间为(0,1)。一般取0.999
,0.9999
等。越大模型越趋于稳定。
num_updates
用来动态设置decay的大小。一般情况下可以用训练步数作为此参数。如果设置了该变量,则衰减率一开始会比较小,后期会越来越大。每次的衰减率计算公式为:min(decay, (1 + num_updates) / (10 + num_updates))
1 | apply(var_list=None) |
ExponentialMovingAverage
通过调用apply
会对参数列表中的每一个Variable
维护一个影子变量(shadow variables)。影子变量的初始值是相应变量的初始值,当每次模型更新时,影子变量的值会更新为:shadow_variable = decay * shadow_variable + (1 - decay) * variable
文档示例代码
1 | # Create variables. |
其他示例见Github
实践——MINST
使用上文所属的三种方法来优化MNIST全连接神经网络,代码如下:
原文代码地址github
1 | import tensorflow as tf |
inference函数中最后输出未使用softmax的原因:在计算损失函数时会一并计算softmax函数,所以这里不需要激活函数。而且不加入softmax不会影响预测结果。因为预测时采用的是不同类别对应节点输出值的相对大小,有没有softmax层对最后分类结果的计算没有影响。于是在计算整个神经网络的前向传播时可以不加入最后的softmax层。
输出如下:
1 | Extracting mnist_data/train-images-idx3-ubyte.gz |