使用 Graph.finalize() 來捕獲新增到圖中的節點
使用 TensorFlow 的最常見模式包括首先構建 TensorFlow 運算子的資料流圖(如 tf.constant()
和 tf.matmul()
,然後通過在迴圈中呼叫 tf.Session.run()
方法(例如訓練迴圈)來執行步驟 )。
記憶體洩漏的常見來源是訓練迴圈包含將節點新增到圖形的呼叫,並且這些呼叫在每次迭代中執行,從而導致圖形增長。這些可能是顯而易見的(例如,呼叫 TensorFlow 運算子,如 tf.square()
),隱式(例如呼叫 TensorFlow 庫函式建立運算子,如 tf.train.Saver()
),或微妙(例如呼叫 tf.Tensor
和 NumPy 陣列上的過載運算子) ,隱含地呼叫 tf.convert_to_tensor()
並向圖中新增新的 tf.constant()
。
該 tf.Graph.finalize()
方法可以幫助趕上這樣的洩漏:它標誌著一個圖形為只讀,如果有什麼被新增到圖中引發了異常。例如:
loss = ...
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
sess.graph.finalize() # Graph is read-only after this statement.
for _ in range(1000000):
sess.run(train_op)
loss_sq = tf.square(loss) # Exception will be thrown here.
sess.run(loss_sq)
在這種情況下,過載的*
運算子會嘗試向圖中新增新節點:
loss = ...
# ...
with tf.Session() as sess:
# ...
sess.graph.finalize() # Graph is read-only after this statement.
# ...
dbl_loss = loss * 2.0 # Exception will be thrown here.