使用 Graph.finalize() 來捕獲新增到圖中的節點

使用 TensorFlow 的最常見模式包括首先構建 TensorFlow 運算子的資料流圖(如 tf.constant()tf.matmul(),然後通過在迴圈中呼叫 tf.Session.run() 方法(例如訓練迴圈)來執行步驟 )。

記憶體洩漏的常見來源是訓練迴圈包含將節點新增到圖形的呼叫,並且這些呼叫在每次迭代中執行,從而導致圖形增長。這些可能是顯而易見的(例如,呼叫 TensorFlow 運算子,如 tf.square()),隱式(例如呼叫 TensorFlow 庫函式建立運算子,如 tf.train.Saver()),或微妙(例如呼叫 tf.Tensor 和 NumPy 陣列上的過載運算子) ,隱含地呼叫 tf.convert_to_tensor() 並向圖中新增新的 tf.constant()

tf.Graph.finalize() 方法可以幫助趕上這樣的洩漏:它標誌著一個圖形為只讀,如果有什麼被新增到圖中引發了異常。例如:

loss = ...
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    sess.graph.finalize()  # Graph is read-only after this statement.

    for _ in range(1000000):
        sess.run(train_op)
        loss_sq = tf.square(loss)  # Exception will be thrown here.
        sess.run(loss_sq)

在這種情況下,過載的*運算子會嘗試向圖中新增新節點:

loss = ...
# ...
with tf.Session() as sess:
    # ...
    sess.graph.finalize()  # Graph is read-only after this statement.
    # ...
    dbl_loss = loss * 2.0  # Exception will be thrown here.