使用批处理读取 n 个时期的数据

假设你的数据示例已经读取到 python 的变量,并且你希望以给定大小的批量读取 n 次:

import numpy as np
import tensorflow as tf
data = np.array([1, 2, 3, 4, 5])
n = 4

要批量合并数据,可能使用随机改组,你可以使用 tf.train.batchtf.train.batch_shuffle,但你需要传递一个会产生 n 次全数据的张量:

limited_tensor = tf.train.limit_epochs(data, n)
batch = tf.train.shuffle_batch([limited_tensor], batch_size=3, enqueue_many=True, capacity=4)

limit_epochs 将 numpy 数组转换为引擎盖下的张量并返回一个张量,产生 n 次并随后抛出 OutOfRangeError。传递给 shuffle_batchenqueue_many=True 参数表示张量列表 [limited_tensor] 中的每个张量应该被解释为包含许多示例。请注意,批处理队列的容量可能小于张量中的示例数。

人们可以像往常一样处理数据:

with tf.Session() as sess:
  sess.run(tf.initialize_local_variables())
  tf.train.start_queue_runners()
  try:
    while True:
      data_batch = sess.run(batch)
      # process data
  except tf.errors.OutOfRangeError:
    pass