如何解决通过在分类之前添加其他信息来自定义CNN pytorch
我正在尝试使用Pytorch创建自定义的CNN架构。当前的体系结构用于文本多标签分类,但是我想将该文本的类别添加为CNN的附加输入,以帮助它记住文本来自哪个父类别。我想添加一个包含所有类别的热门向量。
我当前的代码:
class CNN(nn.Module):
"""
Convolutional Neural Model used for training the models. The total number of kernels that will be used in this
CNN is Co * len(Ks).
Args:
weights_matrix: numpy.ndarray,the shape of this n-dimensional array must be (words,dims) were words is
the number of words in the vocabulary and dims is the dimensionality of the word embeddings.
Co (number of filters): integer,stands for channels out and it is the number of kernels of the same size that will be used.
Hu: integer,stands for number of hidden units in the hidden layer.
C: integer,number of units in the last layer (number of classes)
Ks: list,list of integers specifying the size of the kernels to be used.
"""
def __init__(self,vocab_size,emb_dim,Co,Hu,C,Ks,name = 'generic'):
super(CNN,self).__init__()
self.num_embeddings = vocab_size
self.embeddings_dim = emb_dim
self.padding_index = 0
self.cnn_name = 'cnn_' + str(emb_dim) + '_' + str(Co) + '_' + str(Hu) + '_' + str(C) + '_' + str(Ks) + '_' + name
self.Co = Co
self.Hu = Hu
self.C = C
self.Ks = Ks
self.embedding = nn.Embedding(self.num_embeddings,self.embeddings_dim,self.padding_index)
self.convolutions = nn.ModuleList([nn.Conv2d(1,self.Co,(k,self.embeddings_dim)) for k in self.Ks])
self.relu = nn.ReLU()
self.drop_out = nn.Dropout(p=0.5)
units = [self.Co * len(self.Ks)] + Hu
self.linear_layers = nn.ModuleList([nn.Linear(units[k],units[k+1]) for k in range(len(units)-1)])
self.linear_last = nn.Linear(self.Hu[-1],self.C)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.embedding(x)
x = [self.relu(conv(x)).squeeze(3) for conv in self.convolutions]
x = [F.max_pool1d(i,i.size(2)).squeeze(2) for i in x]
x = torch.cat(x,1)
x = linear(x)
x = self.relu(x)
x = self.drop_out(x)
x = self.linear_last(x)
x = self.sigmoid(x)
return x
我想添加一个线性层,该层具有一个热向量作为输入并将该层连接到我的神经网络(将CNN的输出与新层连接起来),并且AFAIK PyTorch自己进行反向传播。
我是Pytorch的新手,因此,如果您可以对我进行些修改,或者对任何有用的方向提出建议,都可以帮助我。谢谢
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。