为了让训练结果可以复用,下面介绍如何将训练得到的网络模型持久化。
代码实现
tf.train.Saver
有关[tf.train.Saver
]类的官网文档见这里或者GitHub
简单实现
保存代码:
1 | import tensorflow as tf |
运行代码后输出
1 | 'Saved_model/model.ckpt' |
观察当前文件夹,新生成了Saved_model
文件夹,其中包含四个文件:
checkpoint
:保存了一个目录下所有的模型文件列表。model.ckpt.data-00000-of-00001
:保存了TensorFlow当前参数值。model.ckpt.index
:保存了TensorFlow当前参数名。model.ckpt.meta
:保存了TensorFlow计算图的结构。
加载代码:
1 | import tensorflow as tf |
输出如下:
1 | INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt |
该代码首先定义了Tensorflow计算图中的所有运算结构,然后从本地文件中读入变量的值,不需要初始化变量。
加载持久化的图
若我们不希望代码中再次定义所有的结构,则可以加载已经保存了的图结构。代码如下:
1 | import tensorflow as tf |
输出如下
1 | INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt |
上述所有代码,默认保存和加载了TensorFlow计算图中定义的全部变量。
保存指定变量
保存代码
1 | import tensorflow as tf |
上述程序会出错,报错信息如下:
1
2 > tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2
>
读取时对变量重命名
保存代码如下:
1 | import tensorflow as tf |
利用字典来重命名变量,key为结构图中的变量name,value为本地变量。加载代码如下:
1 | import tensorflow as tf |
保存和加载滑动平均模型
使用变量重命名方式
保存代码如下:
1 | import tensorflow as tf |
输出如下
1 | v:0 |
加载代码,因为滑动平均模型的特性,读取变量v的值,实际是要读取变量v的滑动平均值。
1 | import tensorflow as tf |
输出如下
1 | INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt |
使用variables_to_restore
为了方便加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage
类提供了variables_to_restore
(Docs,Github)函数来生成tf.train.Saver
类所需要的变量重命名字典。
代码如下:
1 | import tensorflow as tf |
输出如下:
1 | {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} |
PB文件保存
保存
1 | import tensorflow as tf |
graph_def = tf.get_default_graph().as_graph_def()
:导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程。
graph_util.convert_variables_to_constants
:将图中的变量和取值转化为常量。此时只生成了一个文件
combined_model.pb
。
输出
1 | INFO:tensorflow:Froze 2 variables. |
加载代码
1 | import tensorflow as tf |
输出
1 | [array([ 3.], dtype=float32)] |
持久化原理和数据格式
TensorFlow保存的文件为Protocol Buffer形式的。下面首页介绍这种格式的文件。
Protocol Buffer
Protocol Buffer是Google开发的处理结构化数据的工具。类似的还有XML、JSON。
比如需要保存以下的一些结构化信息:
1 | name: 张三 |
XML保存:
1 | <user> |
JSON保存
1 | { |
Protocol Buffer与这两者的区别:
- XML和JSON格式的数据,序列化后为可读的字符串,该字符串中包含所有信息。
- Protocol Buffer序列化后为不可读的二进制流,使用Protocol Buffer需先定义数据的格式(schema),还原数据时也需要相应的格式。
- Protocol Buffer序列化后的数据比XML或JSON小3到10倍,解析时间快20到100倍。
格式schema文件定义如下:
1 | message user{ |
.ckpt.meta —— MetaGraphDef
TensorFlow是一个通过图的形式来表述计算的编程系统,TensorFlow程序中的所有计算都会被表达为计算图上的节点。TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。
类型定义如下,详见Github:
1 | message MetaGraphDef{ |
以上信息都保存在了model.ckpt.meta
文件中,此为二进制文件,无法直接查看。为了方便调试,TensorFlow提供了export_meta_graph
函数,支持以Json格式导出Protocol Buffer。代码如下
1 | import tensorflow as tf |
查看Json文件
1 | meta_info_def { |
meta_info_def属性
保存了Tensorflow计算图中的元数据和程序中所有用到的运算方法的信息。
定义如下:
1 | message MetaInfoDef { |
OpList
定义见Github
在OpDef中的attr属性中,必须包含name为T的属性,指定了运算输入输出允许的参数类型。
graph_def
主要记录计算图上的节点信息。
saver_def
主要记录持久化模型时需要用到的一些参数,比如保存到文件的文件名、保存操作和加载操作的名称以及保存频率、清理历史纪录等。
collection_def
维护不同的集合,是一个从集合名称到集合内容的映射。
.ckpt
TensorFlow采用tf.train.NewCheckpointReader
来读取ckpt文件中的所有变量信息。
1 | import tensorflow as tf |
tf.train.NewCheckpointReader
读取ckpt文件中的所有变量。variable_name
为变量名称all_variables[variable_name]
为变量维度
输出如下:
1 | v2 [1] |
checkpoint
tf.train.Saver
类自动生成且维护,记录所有Tensorflow模型文件的文件名。可读。
格式如下:
1 | message CheckpointState{ |
实例如下:
1 | model_checkpoint_path: "model.ckpt" |