使用张量的类似 Numpy 的索引

这个例子基于这篇文章: TensorFlow - 类似 numpy 的张量索引

在 Numpy 中,你可以使用数组索引到数组中。例如,为了在二维数组中选择 (1, 2)(3, 2) 中的元素,你可以这样做:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]
b = [2, 2]
selected = data[a, b]
print(selected)

这将打印:

[ 8 20]

要在 Tensorflow 中获得相同的行为,你可以使用 tf.gather_nd ,它是 tf.gather 的扩展。上面的例子可以这样写:

x = tf.constant(data)
idx1 = tf.constant(a)
idx2 = tf.constant(b)
result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))
        
with tf.Session() as sess:
    print(sess.run(result))

这将打印:

[ 8 20]

tf.stack 相当于 np.asarray,在这种情况下,沿着最后一个维度(在本例中为第一维)堆叠两个索引向量,以产生:

[[1 2]
 [3 2]]