使用批處理讀取 n 個時期的資料

假設你的資料示例已經讀取到 python 的變數,並且你希望以給定大小的批量讀取 n 次:

import numpy as np
import tensorflow as tf
data = np.array([1, 2, 3, 4, 5])
n = 4

要批量合併資料,可能使用隨機改組,你可以使用 tf.train.batchtf.train.batch_shuffle,但你需要傳遞一個會產生 n 次全資料的張量:

limited_tensor = tf.train.limit_epochs(data, n)
batch = tf.train.shuffle_batch([limited_tensor], batch_size=3, enqueue_many=True, capacity=4)

limit_epochs 將 numpy 陣列轉換為引擎蓋下的張量並返回一個張量,產生 n 次並隨後丟擲 OutOfRangeError。傳遞給 shuffle_batchenqueue_many=True 參數列示張量列表 [limited_tensor] 中的每個張量應該被解釋為包含許多示例。請注意,批處理佇列的容量可能小於張量中的示例數。

人們可以像往常一樣處理資料:

with tf.Session() as sess:
  sess.run(tf.initialize_local_variables())
  tf.train.start_queue_runners()
  try:
    while True:
      data_batch = sess.run(batch)
      # process data
  except tf.errors.OutOfRangeError:
    pass