為什麼要使用 tf.py func

tf.py_func() 運算子使你可以在 TensorFlow 圖的中間執行任意 Python 程式碼。包裝自定義 NumPy 運算子特別方便,因為沒有等效的 TensorFlow 運算子(尚未存在)。新增 tf.py_func() 是在圖形中使用 sess.run() 呼叫的替代方法。

另一種方法是將圖形分為兩部分:

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

使用 tf.py_func,這更容易:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)