如何解决Python Numpy从火车数据集中过滤标签零的简单方法
我有一个火车数据集trainX和trainy。 TrainX的形状为(n,128,3)火车的形状为(n,1)。
我想删除所有训练标签为零的样本。我有以下代码可以正常工作,但是一些numpy数组的尺寸丢失了。
# load the dataset,returns train and test X and y elements
def load_dataset_train_only(prefix='trainingData128/'):
# load all train
trainX,trainy = load_dataset_group('train',prefix)
print(trainX.shape,trainy.shape)
# zero-offset class values
trainy = trainy
#remove samples where stroke rate = 0
train_filter = np.where(trainy != 0)
trainX,trainy = trainX[train_filter],trainy[train_filter]
print('Size after filtering off zero labels',trainX.shape,trainy.shape)
# one hot encode y
trainy = to_categorical(trainy)
print(trainX.shape,trainy.shape)
return trainX,trainy
过滤前的形状是; (366511,128,3)(366511,1)
过滤后的形状是 滤除零标签后的大小(280905,3)(280905,)
280905值正确,因此它正在过滤零标签,但是如何修改代码以免丢失尺寸?
谢谢
解决方法
对于n-D
数组,np.where
返回所有n
维的索引,因此:
>>> arr = np.arange(12).reshape(3,4)
>>> arr
array([[ 0,1,2,3],[ 4,5,6,7],[ 8,9,10,11]])
>>> np.where(arr > 3)
(array([1,2],dtype=int64),array([0,3,dtype=int64))
因此,只需要第一维,就应该使用:
train_filter = np.where(trainy != 0)[0]
,
您可以使用np.nonzero()
过滤掉不为零的元素。
import numpy as np
x = np.random.randint(0,(10,1))
array([[0],[2],[0],[1],[2]])
not_zero,_ = np.nonzero(x)
array([1,7,9],dtype=int64)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。