pytorch不会保存加载的预训练模型权重及其部分内容到最终模型中

如何解决pytorch不会保存加载的预训练模型权重及其部分内容到最终模型中

我目前正在根据数据在CIFAR-10上进行预训练的模型,删除了模型的最后fc层,并附加了我自己的fc层和softmax。有七个网络,每个网络都与预训练部分相同,并使用附加的fc层进行组合。以下是经过预先训练的网络代码:

class Bottleneck(nn.Module):
    def __init__(self,inplanes,expansion=4,growthRate=12,dropRate=0):
        super(Bottleneck,self).__init__()
        planes = expansion * growthRate
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes,planes,kernel_size=1,bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,growthRate,kernel_size=3,padding=1,bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

        
    def forward(self,x):
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.dropRate > 0:
            out = F.dropout(out,p=self.dropRate,training=self.training)

        out = torch.cat((x,out),1)

        return out


class BasicBlock(nn.Module):
    def __init__(self,expansion=1,dropRate=0):
        super(BasicBlock,x):
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        if self.dropRate > 0:
            out = F.dropout(out,1)

        return out


class Transition(nn.Module):
    def __init__(self,outplanes):
        super(Transition,self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes,outplanes,bias=False)
        self.relu = nn.ReLU(inplace=True)

        
    def forward(self,x):
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = F.avg_pool2d(out,2)
        return out


class DenseNet(nn.Module):

    def __init__(self,depth = 22,block = Bottleneck,dropRate = 0,num_classes = 10,growthRate = 12,compressionRate = 2):
        super(DenseNet,self).__init__()

        assert (depth - 4) % 3 == 0,'depth should be 3n+4'
        n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6

        self.growthRate = growthRate
        self.dropRate = dropRate

        # self.inplanes is a global variable used across multiple
        # helper functions
        self.inplanes = growthRate * 2 
        self.conv1 = nn.Conv2d(3,self.inplanes,kernel_size = 3,padding = 1,bias = False)
        self.dense1 = self._make_denseblock(block,n)
        self.trans1 = self._make_transition(compressionRate)
        self.dense2 = self._make_denseblock(block,n)
        self.trans2 = self._make_transition(compressionRate)
        self.dense3 = self._make_denseblock(block,n)
        self.bn = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        #self.fc = nn.Linear(self.inplanes,num_classes)

        # Weight initialization
#         for m in self.modules():
#             if isinstance(m,nn.Conv2d):
#                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
#                 m.weight.data.normal_(0,math.sqrt(2. / n))
#             elif isinstance(m,nn.BatchNorm2d):
#                 m.weight.data.fill_(1)
#                 m.bias.data.zero_()


    def _make_denseblock(self,block,blocks):
        layers = []
        for i in range(blocks):
            # Currently we fix the expansion ratio as the default value
            layers.append(block(self.inplanes,growthRate = self.growthRate,dropRate=self.dropRate))
            self.inplanes += self.growthRate

        return nn.Sequential(*layers)

    def _make_transition(self,compressionRate):
        inplanes = self.inplanes
        outplanes = int(math.floor(self.inplanes // compressionRate))
        self.inplanes = outplanes
        return Transition(inplanes,outplanes)


    def forward(self,x):
        x = self.conv1(x)

        x = self.trans1(self.dense1(x)) 
        x = self.trans2(self.dense2(x)) 
        x = self.dense3(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        #x = x.view(x.size(0),-1)
        #x = self.fc(x)

        return x
    
    
    def getParams(self,paramName):
        if paramName == 'inplanes':
            return self.inplanes
        elif paramName == 'growthRate':
            return self.growthRate
        elif paramName == 'dropRate':
            return self.dropRate
        
def densenet(**kwargs):
    """
    Constructs a DenseNet model.
    """
    return DenseNet(**kwargs) 

下面是我的代码:

class Network(nn.Module):
    
    def __init__(self,pretrained_dict,num_classes = 6,num_channels = 7,expansion = 4,depth = 100,dropRate = 0):
        
        super(Network,self).__init__()
        
        self.num_channels = num_channels
        
        # creating 7 channels networks 
        self.channels_dnsnets = []
        
        for ch in range(self.num_channels):
#             print(ch)
            
            d = densenet(depth = depth)
            d_dict = d.state_dict()
            
            # 1. filter out unnecessary keys
            pretrained_dict2 = {k[7:]: v for k,v in pretrained_dict.items() if k[7:] in d_dict}
#             print('d_dict_keys :')
#             print(d_dict.keys())
#             print('*'*50)
#             print('pretrained_dict2.keys:')
#             print(pretrained_dict2.keys())
#             print('*'*50)
            
            # 2. overwrite entries in the existing state dict
            d_dict.update(pretrained_dict2) 
            
            # 3. load the new state dict
            d.load_state_dict(pretrained_dict2)
            
            # freeze the layers of densenet
            for param in d.parameters():
                param.requires_grad = False
                
            self.channels_dnsnets.append(d)
            
        self.inplanes = self.channels_dnsnets[0].getParams(paramName = 'inplanes')
        self.fc = nn.Linear(self.inplanes * self.num_channels,num_classes)
        self.softmax = nn.Softmax(dim = 1)
        
        
    def forward(self,x):
        
        batch_size,channels,ht,wd,in_channels = x.shape
        x = np.reshape(x,(batch_size,in_channels,wd))

        out = []
    
        for num in range(self.num_channels):
            temp_out = self.channels_dnsnets[0](x[:,num,:])
            temp_out = temp_out.view(temp_out.size(0),-1)
#             print(temp_out.shape)
#             print('*' * 50)
            out.append(temp_out)
        
        out = torch.stack(out,dim = 1)
#         print(out.shape)
        out = out.view(out.size(0),-1)
        out = self.fc(out)
        out = self.softmax(out)
        return out 

我将优化器设置为:

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,model.parameters()),lr = lr,betas = (0.9,0.999),eps = 1e-08,weight_decay = wd,amsgrad = False)
        

但是,每当我保存模型时,密集网列表及其权重都不会保存,而只会保存fc层和softmax层权重。代码有什么问题吗?我是pytorch的新手。

解决方法

问题是self.channels_dnsnets只是list,不会成为state_dict的一部分。仅self.fcself.softmax将被注册到Module中。最简单的更改就是这样定义:

self.channels_dnsnets = nn.ModuleList()

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 <select id="xxx"> SELECT di.id, di.name, di.work_type, di.updated... <where> <if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 <property name="dynamic.classpath" value="tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-