如何解决过滤Numpy的数组数组
使用numpy的ndarray将数据预处理到神经网络。它基本上包含用于传感器数据的几个固定长度的数组。例如:
>>> type(arr)
<class 'numpy.ndarray'>
>>> arr.shape
(400,1,5,4)
>>> arr
[
[[ 9.4 -3.7 -5.2 3.8]
[ 2.8 1.4 -1.7 3.4]
[ 0.0 0.0 0.0 0.0]
[ 0.0 0.0 0.0 0.0]
[ 0.0 0.0 0.0 0.0]]
..
[[ 0.0 -1.0 2.1 0.0]
[ 3.0 2.8 -3.0 8.2]
[ 7.5 1.7 -3.8 2.6]
[ 0.0 0.0 0.0 0.0]
[ 0.0 0.0 0.0 0.0]]
]
每个嵌套数组的形状为(1,4)
。目标是遍历此arr
并仅将至少具有前三行的那些数组选择为非零(尽管单个条目可以为零,但不能整行)。
因此,在上面给出的示例中,应该删除第一个嵌套数组,因为只有2个第一行非零,而我们需要3个及以上。
解决方法
这是您可以使用的技巧:
mask = arr[:,:,:3].any(axis=3).all(axis=2)
arr_filtered = arr[mask]
快速说明:要保留一个嵌套数组,它应至少有3个第一行(因此我们只需要查看arr[:,:3]
),以使所有它们(因此.all(axis=2)
结尾)都具有至少一个非零条目(因此.any(axis=3)
)。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。