如何解决在PyTorch中提取CNN的中间层输出
我正在使用Resnet18
模型。
ResNet(
(conv1): Conv2d(3,64,kernel_size=(7,7),stride=(2,2),padding=(3,3),bias=False)
(bn1): BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3,stride=2,padding=1,dilation=1,ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64,kernel_size=(3,stride=(1,1),padding=(1,bias=False)
(bn1): BatchNorm2d(64,track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64,bias=False)
(bn2): BatchNorm2d(64,track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64,track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64,128,bias=False)
(bn1): BatchNorm2d(128,track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128,bias=False)
(bn2): BatchNorm2d(128,track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64,kernel_size=(1,bias=False)
(1): BatchNorm2d(128,track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128,track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128,256,bias=False)
(bn1): BatchNorm2d(256,track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256,bias=False)
(bn2): BatchNorm2d(256,track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128,bias=False)
(1): BatchNorm2d(256,track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256,track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256,512,bias=False)
(bn1): BatchNorm2d(512,track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512,bias=False)
(bn2): BatchNorm2d(512,track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256,bias=False)
(1): BatchNorm2d(512,track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512,track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1,1))
(fc): Linear(in_features=512,out_features=1000,bias=True)
)
我只想从layer2
,layer3
,layer4
提取输出,而我不希望avgpool
和fc
的输出。
我该如何实现?
class BasicBlock(nn.Module):
def __init__(self,in_channels,out_channels,stride=1,padding=1) -> None:
super(BasicBlock,self).__init__()
self.conv1 = nn.Conv2d(in_channels,3,stride,padding=padding,bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels,bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
if in_channels != out_channels:
l1 = nn.Conv2d(in_channels,kernel_size=1,stride=stride,bias=False)
l2 = nn.BatchNorm2d(out_channels)
self.downsample = nn.Sequential(l1,l2)
else:
self.downsample = None
def forward(self,xb):
prev = xb
x = self.relu(self.bn1(self.conv1(xb)))
x = self.bn2(self.conv2(x))
if self.downsample is not None:
prev = self.downsample(xb)
x = x + prev
return self.relu(x)
class CustomResnet(nn.Module):
def __init__(self,pretrained:bool=True) -> None:
super(CustomResnet,self).__init__()
self.conv1 = nn.Conv2d(3,kernel_size=7,padding=3,bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3,padding=1)
self.layer1 = nn.Sequential(BasicBlock( 64,stride=1),BasicBlock(64,64))
self.layer2 = nn.Sequential(BasicBlock(64,stride=2),BasicBlock(128,128))
self.layer3 = nn.Sequential(BasicBlock(128,BasicBlock(256,256))
self.layer4 = nn.Sequential(BasicBlock(256,BasicBlock(512,512))
def forward(self,xb):
x = self.maxpool(self.relu(self.bn1(self.conv1(xb))))
x = self.layer1(x)
x2 = x = self.layer2(x)
x3 = x = self.layer3(x)
x4 = x = self.layer4(x)
return [x2,x3,x4]
我想一种解决方案是..但是,如果在编写大量代码时没有编写此代码,还有其他方法吗?在上述经过修改的torchvision
模型中,也可以加载ResNet
给定的预训练权重。
解决方法
如果您知道如何实现forward
方法,则可以对模型进行子类化,并仅覆盖forward
方法。
如果您在PyTorch中使用模型的预训练权重,则您已经可以访问模型的代码。因此,找到模型代码的位置,将其导入,对该模型进行子类化,并覆盖forward
方法。
例如:
class MyResNet18(Resnet):
def __init__(self,*args,**kwargs):
super().__init__(*args,**kwargs)
def forward(self,xb):
x = self.maxpool(self.relu(self.bn1(self.conv1(xb))))
x = self.layer1(x)
x2 = x = self.layer2(x)
x3 = x = self.layer3(x)
x4 = x = self.layer4(x)
return [x2,x3,x4]
您已经完成。
,为了将来参考,有一个pytorch实用程序可以轻松获得中间结果https://pypi.org/project/torch-intermediate-layer-getter/
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。