保存模型

在张量流中保存模型非常简单。

假设你有一个输入 x 的线性模型,并想要预测输出 y。这里的损失是均方误差(MSE)。批量大小为 16。

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

这里有 Saver 对象,它可以有多个参数(参见 doc )。

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

最后,我们在 tf.Session() 中训练模型,进行 1000 迭代。我们只在每个 100 迭代中保存模型。

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

运行此代码后,你应该看到目录中的最后 5 个检查点:

  • model-500model-500.meta
  • model-600model-600.meta
  • model-700model-700.meta
  • model-800model-800.meta
  • model-900model-900.meta

请注意,在这个例子中,虽然 saver 实际上保存了变量的当前值作为检查点和图形的结构(*.meta),但是没有特别注意如何检索例如占位符 xy 一旦模型是恢复。例如,如果在此训练脚本以外的任何地方进行恢复,则从恢复的图形中检索 xy 可能很麻烦(特别是在更复杂的模型中)。为了避免这种情况,请始终为变量/占位符/操作命名,或者考虑使用 tf.collections,如其中一个备注所示。