如何使用 tf.gather nd

tf.gather_nd 是擴充套件 tf.gather 的,因為它可以讓你不僅可以訪問一個張量的 1 維的感覺,但可能所有的人。

引數:

  • params:一個等級 P 的張量,代表我們想要索引的張量
  • indices:一個等級 Q 的張量,代表我們想要訪問的 params 的索引

功能的輸出取決於 indices 的形狀。如果 indices 的最內層尺寸為 P,我們正在從 params 收集單個元素。如果它小於 P,我們正在收集切片,就像 tf.gather 一樣但沒有限制我們只能訪問第一維。

從等級 2 的張量中收集元素

要在矩陣中訪問 (1, 2) 中的元素,我們可以使用:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])

result 將如預期的那樣成為 8。請注意這與 tf.gather 有何不同:傳遞給 tf.gather(x, [1, 2]) 的相同索引將作為 data 的第 2 和第 3 給出。

如果要同時檢索多個元素,只需傳遞一個索引對列表:

result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])

這將返回 [ 8 27 17]

從等級 2 的張量中收集行

如果在上面的示例中你想要收集行(即切片)而不是元素,請按如下方式調整 indices 引數:

data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])

這將給你第 2 和第 4 行 data,即

[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]

從第 3 級的張量中收集元素

如何訪問秩 -2 張量的概念直接轉換為更高維度的張量。因此,要訪問 rank-3 張量中的元素,indices 的最內層維度必須為 3。

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])

result 現在看起來像這樣:[ 0 11]

從 3 級的張量中收集批量行

讓我們把秩 -3 張量想象成一批形狀為 tihuan 的矩陣 26。如果要為批處理中的每個元素收集第一行和第二行,可以使用:

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])

這將導致:

[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]

注意 indices 的形狀如何影響輸出張量的形狀。如果我們在 indices 引數中使用了 rank-2 張量:

result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])

輸出本來是

[[0 1]
 [2 3]
 [6 7]
 [8 9]]