Matching index dimensions in NumPy
Problem description:
Question
I have two NumPy arrays, A and indices. A has dimensions m x n x 10000. indices has dimensions m x n x 5 (the output of argpartition(A, 5)[:,:,:5]). I would like to get an m x n x 5 array containing the elements of A that correspond to indices.
Attempts
import numpy as np
indices = np.array([[[5,4,3,2,1],[1,1,1,1,1],[1,1,1,1,1]],
[[500,400,300,200,100],[100,100,100,100,100],[100,100,100,100,100]]])
A = np.reshape(range(2 * 3 * 10000), (2,3,10000))
A[...,indices] # gives an array of size (2,3,2,3,5). I want a subset of these values
np.take(A, indices) # shape is right, but it flattens the array first
np.choose(indices, A) # fails because of shape mismatch.
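To see why the first two attempts miss the target, it may help to print the shapes and one element. This is a minimal sketch that builds its own A and a well-formed indices array of shape (2, 3, 5); the names outer and flat are only illustrative.

```python
import numpy as np

A = np.arange(2 * 3 * 10000).reshape(2, 3, 10000)
indices = np.array([[[5, 4, 3, 2, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]],
                    [[500, 400, 300, 200, 100], [100] * 5, [100] * 5]])

# Indexing only the last axis keeps A's first two axes and appends
# the full shape of indices.
outer = A[..., indices]
print(outer.shape)  # (2, 3, 2, 3, 5)

# Without an axis argument, np.take indexes into the flattened array.
flat = np.take(A, indices)
print(flat.shape)  # (2, 3, 5)
print(flat[1, 0, 0], A[1, 0, 500])  # 500 30500
```

The last line shows the mismatch: np.take returns element 500 of the flattened array, while the wanted value is A[1, 0, 500].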
Motivation
I'm trying to get the 5 largest values of A[i,j] for each i<m, j<n in sorted order, using np.argpartition because the arrays can get fairly large.
Answer
You can use advanced indexing:
m,n = A.shape[:2]
out = A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices]
Sample run:
In [330]: A
Out[330]:
array([[[38, 21, 61, 74, 35, 29, 44, 46, 43, 38],
[22, 44, 89, 48, 97, 75, 50, 16, 28, 78],
[72, 90, 48, 88, 64, 30, 62, 89, 46, 20]],
[[81, 57, 18, 71, 43, 40, 57, 14, 89, 15],
[93, 47, 17, 24, 22, 87, 34, 29, 66, 20],
[95, 27, 76, 85, 52, 89, 69, 92, 14, 13]]])
In [331]: indices
Out[331]:
array([[[7, 8, 1],
[7, 4, 7],
[4, 8, 4]],
[[0, 7, 4],
[5, 3, 1],
[1, 4, 0]]])
In [332]: m,n = A.shape[:2]
In [333]: A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices]
Out[333]:
array([[[46, 43, 21],
[16, 97, 16],
[64, 46, 64]],
[[81, 14, 43],
[87, 24, 47],
[27, 52, 95]]])
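The indexing expression works because the three index arrays broadcast to a common shape. The sketch below checks this against an explicit loop on random data; the names rows, cols and expected, and the array sizes, are chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 2, 3, 10
A = rng.integers(0, 100, size=(m, n, k))
indices = rng.integers(0, k, size=(m, n, 3))

# Row and column index arrays shaped (m, 1, 1) and (n, 1) broadcast
# against indices, shaped (m, n, 3), so all three share one shape.
rows = np.arange(m)[:, None, None]
cols = np.arange(n)[:, None]
out = A[rows, cols, indices]

# Reference result: out[i, j, t] == A[i, j, indices[i, j, t]]
expected = np.empty_like(indices)
for i in range(m):
    for j in range(n):
        expected[i, j] = A[i, j, indices[i, j]]

assert np.array_equal(out, expected)
```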
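To get the indices of the 5 largest elements along the last axis, use argpartition, like so:
indices = np.argpartition(-A,5,axis=-1)[...,:5]
To keep the order from highest to lowest, pass range(5) instead of 5 as the kth argument.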
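The difference between the two kth choices can be checked on a small array. This is a sketch on random data with distinct values (so ties do not blur the comparison); the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.permutation(2 * 3 * 50).reshape(2, 3, 50)  # distinct values, no ties

# kth=5: the first 5 positions hold the 5 largest values, in no particular order.
unordered = np.argpartition(-A, 5, axis=-1)[..., :5]
# kth=range(5): positions 0..4 are each placed in their final sorted position.
ordered = np.argpartition(-A, range(5), axis=-1)[..., :5]

m, n = A.shape[:2]
rows, cols = np.arange(m)[:, None, None], np.arange(n)[:, None]
top_unordered = A[rows, cols, unordered]
top_ordered = A[rows, cols, ordered]

# Reference: full descending sort, first 5 values.
reference = -np.sort(-A, axis=-1)[..., :5]
assert np.array_equal(top_ordered, reference)
assert np.array_equal(np.sort(top_unordered, axis=-1), reference[..., ::-1])
```

Both calls select the same five elements per (i, j); only the range(5) form guarantees they come out in descending order.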
Answer
For posterity, the following uses Divakar's answer to accomplish the original goal, i.e. returning the top 5 values in sorted order for all i<m, j<n:
m, n = np.shape(A)[:2]
# get the indices of the 5 largest values for every (i, j); order within the 5 is arbitrary
top_unsorted_indices = np.argpartition(A, -5, axis=2)[...,-5:]
# get the values corresponding to top_unsorted_indices
top_values = A[np.arange(m)[:,None,None], np.arange(n)[:,None], top_unsorted_indices]
# reorder the top-5 indices so that their values run from largest to smallest
top_sorted_indices = top_unsorted_indices[np.arange(m)[:,None,None], np.arange(n)[:,None], np.argsort(-top_values)]
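As a possible alternative, np.take_along_axis (available in NumPy 1.15 and later) does the same gather without building the row and column index arrays by hand. This sketch is not part of the original answers; it assumes random test data, and top_sorted_values is an added name that holds the sorted values themselves.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.permutation(2 * 3 * 1000).reshape(2, 3, 1000)

# Indices of the 5 largest entries along the last axis, unordered.
top_idx = np.argpartition(A, -5, axis=-1)[..., -5:]
top_vals = np.take_along_axis(A, top_idx, axis=-1)

# Order those 5 entries from largest to smallest.
order = np.argsort(-top_vals, axis=-1)
top_sorted_indices = np.take_along_axis(top_idx, order, axis=-1)
top_sorted_values = np.take_along_axis(top_vals, order, axis=-1)

# Check against a full descending sort of the last axis.
assert np.array_equal(top_sorted_values, -np.sort(-A, axis=-1)[..., :5])
```

Only the 5 selected entries per (i, j) are sorted, so the full argsort over the last axis is avoided.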