我有一个2darray如下.我想通过数组中的每一行找到高于阈值(例如0.7)的值的索引.
items= np.array([[1.,0.40824829,0.03210806,0.29488391,0.,0.5,0.32444284,0.57735027,0.5 ],[0.40824829,1.,0.57675476,0.48154341,0.81649658,0.79471941,0.70710678,0.40824829],[0.03210806,0.42606683,0.92713363,0.834192,0.73848549],[0.29488391,0.52620136,0.51075392,0.20851441,0.44232587],[0.,0. ],[0.5,0.28867513,[0.32444284,0.93658581,0.22941573,0.81110711],[0.57735027,0.8660254 ],0.73848549,0.44232587,0.81110711,0.8660254,1. ]])
indices_items = np.argwhere(items>= 0.7)
此(indices_items)返回
array([[0,0],[1,1],5],6],7],[2,2],9],[3,3],[5,8],[6,[7,[8,[9,9]],dtype=int64)
我怎样才能按行获取索引,如下所示?
第0行-> [0]行1- [0,1,5,6,7]行2-> [2,7,9]行3-> [3] row4-> []
#此列表应该为空,因为没有超出阈值的值…
最佳答案
获取带有np.where的行,然后使用np.searchsorted来获取行数组上的间隔索引,并使用它们来拆分col-array-
In [38]: r,c = np.where(items>= 0.7)
In [39]: np.split(c,np.searchsorted(r,range(1,items.shape[0])))
Out[39]:
[array([0],dtype=int64),array([1,array([2,array([3],array([],2,array([5,dtype=int64)]
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。