从张量的第一维提取非连续切片
通常,tf.gather
允许你访问张量的第一维中的元素(例如,二维张量中的行 1,3 和 7)。如果你需要访问除第一个之外的任何其他维度,或者如果你不需要整个切片,但是例如仅在第 1 行,第 3 行和第 7 行中的第 5 个条目,则最好使用 tf.gather_nd
(请参阅即将发布的示例这个)。
tf.gather
论点:
params
:要从中提取值的张量。indices
:张量,指定指向params
的索引
有关详细信息, 请参阅 tf.gather(params, indices) 文档。
我们想要在二维张量中提取第 1 和第 4 行。
# data is [[0, 1, 2, 3, 4, 5],
# [6, 7, 8, 9, 10, 11],
# ...
# [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
params = tf.constant(data)
indices = tf.constant([0, 3])
selected = tf.gather(params, indices)
selected
具有形状 [2, 6]
并打印其值
[[ 0 1 2 3 4 5]
[18 19 20 21 22 23]]
indices
也可以只是一个标量(但不能包含负数指数)。例如在上面的例子中:
tf.gather(params, tf.constant(3))
会打印
[18 19 20 21 22 23]
请注意,indices
可以具有任何形状,但存储在 indices
中的元素始终仅指 params
的第一维。例如,如果要同时检索第 1 行和第 3 行以及第 2 行和第 4 行,则可以执行以下操作:
indices = tf.constant([[0, 2], [1, 3]])
selected = tf.gather(params, indices)
现在 selected
将形成 [2, 2, 6]
,其内容为:
[[[ 0 1 2 3 4 5]
[12 13 14 15 16 17]]
[[ 6 7 8 9 10 11]
[18 19 20 21 22 23]]]
你可以使用 tf.gather
来计算排列。例如,以下内容会反转 params
的所有行:
indices = tf.constant(list(range(4, -1, -1)))
selected = tf.gather(params, indices)
selected
现在
[[24 25 26 27 28 29]
[18 19 20 21 22 23]
[12 13 14 15 16 17]
[ 6 7 8 9 10 11]
[ 0 1 2 3 4 5]]
如果你需要访问除第一个维度之外的任何其他维度,你可以使用 tf.transpose
解决这个问题:例如,在我们的示例中收集列而不是行,你可以这样做:
indices = tf.constant([0, 2])
selected = tf.gather(tf.transpose(params, [1, 0]), indices)
selected_t = tf.transpose(selected, [1, 0])
selected_t
的形状为 [5, 2]
,内容如下:
[[ 0 2]
[ 6 8]
[12 14]
[18 20]
[24 26]]
然而,tf.transpose
相当昂贵,所以在这个用例中使用 tf.gather_nd
可能会更好。