如何使用 tf.gather nd

tf.gather_nd 是扩展 tf.gather 的,因为它可以让你不仅可以访问一个张量的 1 维的感觉,但可能所有的人。

参数:

  • params:一个等级 P 的张量,代表我们想要索引的张量
  • indices:一个等级 Q 的张量,代表我们想要访问的 params 的索引

功能的输出取决于 indices 的形状。如果 indices 的最内层尺寸为 P,我们正在从 params 收集单个元素。如果它小于 P,我们正在收集切片,就像 tf.gather 一样但没有限制我们只能访问第一维。

从等级 2 的张量中收集元素

要在矩阵中访问 (1, 2) 中的元素,我们可以使用:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])

result 将如预期的那样成为 8。请注意这与 tf.gather 有何不同:传递给 tf.gather(x, [1, 2]) 的相同索引将作为 data 的第 2 和第 3 给出。

如果要同时检索多个元素,只需传递一个索引对列表:

result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])

这将返回 [ 8 27 17]

从等级 2 的张量中收集行

如果在上面的示例中你想要收集行(即切片)而不是元素,请按如下方式调整 indices 参数:

data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])

这将给你第 2 和第 4 行 data,即

[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]

从第 3 级的张量中收集元素

如何访问秩 -2 张量的概念直接转换为更高维度的张量。因此,要访问 rank-3 张量中的元素,indices 的最内层维度必须为 3。

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])

result 现在看起来像这样:[ 0 11]

从 3 级的张量中收集批量行

让我们把秩 -3 张量想象成一批形状为 tihuan 的矩阵 26。如果要为批处理中的每个元素收集第一行和第二行,可以使用:

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])

这将导致:

[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]

注意 indices 的形状如何影响输出张量的形状。如果我们在 indices 参数中使用了 rank-2 张量:

result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])

输出本来是

[[0 1]
 [2 3]
 [6 7]
 [8 9]]