如何解决pytorch二进制分类中如何处理不平衡类
我正在研究二进制文本分类问题。我该如何应用smote或WeightedRandomSample来解决数据集中的不平衡问题。我的代码目前看起来像这样:
class GDataset(Dataset):
def __init__(self,passage,targets,tokenizer,max_len):
self.passage = passage
self.targets = targets
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.passage)
def __getitem__(self,item):
passage = str(self.passage[item])
target = self.targets[item]
if (target == 1) and self.transform: # minority class
x = self.transform(x)
encoding = self.tokenizer.encode_plus(
passage,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,pad_to_max_length=True,return_attention_mask=True,return_tensors='pt',)
return
'passage_text': passage,'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'targets': torch.tensor(target,dtype=torch.long
我如何使用其他平衡技术?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。