使用張量的類似 Numpy 的索引

這個例子基於這篇文章: TensorFlow - 類似 numpy 的張量索引

在 Numpy 中,你可以使用陣列索引到陣列中。例如,為了在二維陣列中選擇 (1, 2)(3, 2) 中的元素,你可以這樣做:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]
b = [2, 2]
selected = data[a, b]
print(selected)

這將列印:

[ 8 20]

要在 Tensorflow 中獲得相同的行為,你可以使用 tf.gather_nd ,它是 tf.gather 的擴充套件。上面的例子可以這樣寫:

x = tf.constant(data)
idx1 = tf.constant(a)
idx2 = tf.constant(b)
result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))
        
with tf.Session() as sess:
    print(sess.run(result))

這將列印:

[ 8 20]

tf.stack 相當於 np.asarray,在這種情況下,沿著最後一個維度(在本例中為第一維)堆疊兩個索引向量,以產生:

[[1 2]
 [3 2]]