從張量的第一維提取非連續切片

通常,tf.gather 允許你訪問張量的第一維中的元素(例如,二維張量中的行 1,3 和 7)。如果你需要訪問除第一個之外的任何其他維度,或者如果你不需要整個切片,但是例如僅在第 1 行,第 3 行和第 7 行中的第 5 個條目,則最好使用 tf.gather_nd(請參閱即將發​​布的示例這個)。

tf.gather 論點:

  • params:要從中提取值的張量。
  • indices:張量,指定指向 params 的索引

有關詳細資訊 請參閱 tf.gather(params, indices) 文件。

我們想要在二維張量中提取第 1 和第 4 行。

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          ...
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
params = tf.constant(data)
indices = tf.constant([0, 3])
selected = tf.gather(params, indices)

selected 具有形狀 [2, 6] 並列印其值

[[ 0  1  2  3  4  5]
 [18 19 20 21 22 23]]

indices 也可以只是一個標量(但不能包含負數指數)。例如在上面的例子中:

tf.gather(params, tf.constant(3))

會列印

[18 19 20 21 22 23]

請注意,indices 可以具有任何形狀,但儲存在 indices 中的元素始終僅指 params第一維。例如,如果要同時檢索第 1 行第 3 行以及第 2 行第 4 行,則可以執行以下操作:

indices = tf.constant([[0, 2], [1, 3]])
selected = tf.gather(params, indices)

現在 selected 將形成 [2, 2, 6],其內容為:

[[[ 0  1  2  3  4  5]
  [12 13 14 15 16 17]]

 [[ 6  7  8  9 10 11]
  [18 19 20 21 22 23]]]

你可以使用 tf.gather 來計算排列。例如,以下內容會反轉 params 的所有行:

indices = tf.constant(list(range(4, -1, -1)))
selected = tf.gather(params, indices)

selected 現在

[[24 25 26 27 28 29]
 [18 19 20 21 22 23]
 [12 13 14 15 16 17]
 [ 6  7  8  9 10 11]
 [ 0  1  2  3  4  5]]

如果你需要訪問除第一個維度之外的任何其他維度,你可以使用 tf.transpose 解決這個問題:例如,在我們的示例中收集列而不是行,你可以這樣做:

indices = tf.constant([0, 2])
selected = tf.gather(tf.transpose(params, [1, 0]), indices)
selected_t = tf.transpose(selected, [1, 0]) 

selected_t 的形狀為 [5, 2],內容如下:

[[ 0  2]
 [ 6  8]
 [12 14]
 [18 20]
 [24 26]]

然而,tf.transpose 相當昂貴,所以在這個用例中使用 tf.gather_nd 可能會更好。