TypeError: forward() missing 1 required positional argument: 'c'

How to fix "TypeError: forward() missing 1 required positional argument: 'c'"

I created this simplified version of VGG16:

import torch
import torch.nn as nn

class VGG16COMBO(nn.Module):

    def __init__(self, num_classes):
        super(VGG16COMBO, self).__init__()

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        self.block_1 = nn.Sequential(
            # (1(32-1) - 32 + 3)/2 = 1
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )

        self.block_3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )

        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.25),
            nn.Linear(4096, num_classes),
        )

    def forward(self,m,c):

        m = self.block_1(m)
        m = self.block_2(m)
        m = self.block_3(m)
        m = self.block_4(m)
        m = m.view(m.size(0),-1)
        m = self.classifier(m)

        c = self.block_1(c)
        c = self.block_2(c)
        c = self.block_3(c)
        c = self.block_4(c)
        c = c.view(c.size(0),-1)
        c = self.classifier(c)

        x = torch.cat((m,c),dim=1)
        return x

As you can see, in forward I pass 2 inputs, m and c. m refers to MNIST and c refers to CIFAR10, because I want a multi-input neural network (a network with shared weights). Then:
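The error itself can be reproduced with a much smaller module: any nn.Module whose forward takes two positional arguments must be called with both. A minimal sketch (the TwoInput class, its layer, and the tensor shapes are made up for illustration):

```python
import torch
import torch.nn as nn

class TwoInput(nn.Module):
    """Toy two-input module with shared weights, mirroring forward(self, m, c)."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)  # shared layer applied to both inputs

    def forward(self, m, c):
        # both branches go through the same layer, outputs concatenated
        return torch.cat((self.fc(m), self.fc(c)), dim=1)

net = TwoInput()
m, c = torch.randn(1, 4), torch.randn(1, 4)
out = net(m, c)      # OK: both positional arguments supplied
print(out.shape)     # torch.Size([1, 4])
# net(m) alone raises:
# TypeError: forward() missing 1 required positional argument: 'c'
```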

modelcombo = VGG16COMBO(1).cuda()
print(modelcombo)

# Define an optimizer
import torch.optim as optim
optimizer = optim.SGD(modelcombo.parameters(),lr = 0.01)
# Define a loss 
criterion = nn.BCEWithLogitsLoss()

This is my training function:

# train, to be modified to use both datasets
def train(net, loaders, optimizer, criterion, epochs=20, dev=dev, save_param=False, model_name="valerio"):
    try:
        net = net.to(dev)
        #print(net)
        # Initialize history
        history_loss = {"train": [], "val": [], "test": []}
        history_accuracy = {"train": [], "val": [], "test": []}
        # Store the best val accuracy
        best_val_accuracy = 0

        # Process each epoch
        for epoch in range(epochs):
            # Initialize epoch variables
            sum_loss = {"train": 0, "val": 0, "test": 0}
            sum_accuracy = {"train": 0, "val": 0, "test": 0}
            # Process each split
            for split in ["train","val","test"]:
                if split == "train":
                  net.train()
                else:
                  net.eval()
                # Process each batch
                for (input,labels) in loaders[split]:
                    # Move to CUDA
                    input = input.to(dev)
                    labels = labels.to(dev)
                    # Reset gradients
                    optimizer.zero_grad()
                    # Compute output
                    pred = net(input)
                    #pred = pred.squeeze(dim=1) # Output shape is [Batch size,1],but we want [Batch size]
                    labels = labels.unsqueeze(1)
                    labels = labels.float()
                    loss = criterion(pred,labels)
                    # Update loss
                    sum_loss[split] += loss.item()
                    # Check parameter update
                    if split == "train":
                        # Compute gradients
                        loss.backward()
                        # Optimize
                        optimizer.step()
                    # Compute accuracy
                    #pred_labels = pred.argmax(1) + 1
                    pred_labels = (pred >= 0.5).long() # Binarize predictions to 0 and 1
                    batch_accuracy = (pred_labels == labels).sum().item()/input.size(0)
                    # Update accuracy
                    sum_accuracy[split] += batch_accuracy
            # Compute epoch loss/accuracy
            epoch_loss = {split: sum_loss[split]/len(loaders[split]) for split in ["train", "val", "test"]}
            epoch_accuracy = {split: sum_accuracy[split]/len(loaders[split]) for split in ["train", "val", "test"]}

            # Store params at the best validation accuracy
            if save_param and epoch_accuracy["val"] > best_val_accuracy:
              #torch.save(net.state_dict(),f"{net.__class__.__name__}_best_val.pth")
              torch.save(net.state_dict(),f"{model_name}_best_val.pth")
              best_val_accuracy = epoch_accuracy["val"]

            # Update history
            for split in ["train","test"]:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])
            # Print info
            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},")
    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Plot loss
        plt.title("Loss")
        for split in ["train","test"]:
            plt.plot(history_loss[split],label=split)
        plt.legend()
        plt.show()
        # Plot accuracy
        plt.title("Accuracy")
        for split in ["train","test"]:
            plt.plot(history_accuracy[split],label=split)
        plt.legend()
        plt.show()

But when I run training:

# Train model
train(modelcombo,epochs=10,dev=dev)

I get this error:

TypeError: forward() missing 1 required positional argument: 'c'

What do I have to change, the network or the training function? I think the problem is in the training function, because I have to pass both loaders and loaders_cifar, but I don't know how. In particular, do I have to concatenate the MNIST loader and the CIFAR loader before passing them to the training function, or do I have to modify the loop, turning for (input,labels) in loaders[split]: into something like for (input,labels) in loaders[split] and loaders_cifar[split]:?

EDIT: I created this function:

def itr_merge(*itrs):
    for itr in itrs:
        for v in itr:
            yield v
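Note that itr_merge chains the iterators one after the other: every yielded item is still a single (input, labels) batch from one dataset, so it cannot supply the two separate tensors that forward(m, c) expects. A quick check with plain lists (the list values are purely illustrative):

```python
def itr_merge(*itrs):
    # Yield every item of the first iterable, then every item of the next
    for itr in itrs:
        for v in itr:
            yield v

merged = list(itr_merge([1, 2], [3, 4]))
print(merged)  # [1, 2, 3, 4] -- one stream after the other, not pairs
# By contrast, zip([1, 2], [3, 4]) gives [(1, 3), (2, 4)] -- one item from each
```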

and edited the training function like this:

# train, to be modified to use both datasets
def train2(net, loaders, loaders_cifar, optimizer, criterion,
           epochs=20, dev=dev, save_param=False, model_name="valerio"):
    try:
        net = net.to(dev)
        # Initialize history
        history_loss = {"train": [], "val": [], "test": []}
        history_accuracy = {"train": [], "val": [], "test": []}
        # Process each epoch
        for epoch in range(epochs):
            # Initialize epoch variables
            sum_loss = {"train": 0, "val": 0, "test": 0}
            sum_accuracy = {"train": 0, "val": 0, "test": 0}
            # Process each split
            for split in ["train", "val", "test"]:
                if split == "train":
                  net.train()
                else:
                  net.eval()
                # Process each batch
                for x in itr_merge(loaders[split], loaders_cifar[split]):
                  for (input, labels) in loaders[split]:
                      # Move to CUDA
                      input = input.to(dev)
                      labels = labels.to(dev)
                      # Reset gradients
                      optimizer.zero_grad()
                      # Compute output
                      pred = net(input)
                      #pred = pred.squeeze(dim=1) # Output shape is [Batch size, 1], but we want [Batch size]
                      labels = labels.unsqueeze(1)
                      labels = labels.float()
                      loss = criterion(pred, labels)
                      # Update loss
                      sum_loss[split] += loss.item()
                      # Check parameter update
                      if split == "train":
                          # Compute gradients
                          loss.backward()
                          # Optimize
                          optimizer.step()
                      # Compute accuracy
                      #pred_labels = pred.argmax(1) + 1
                      pred_labels = (pred >= 0.5).long() # Binarize predictions to 0 and 1
                      batch_accuracy = (pred_labels == labels).sum().item()/input.size(0)
                      # Update accuracy
                      sum_accuracy[split] += batch_accuracy
            # Compute epoch loss/accuracy
            epoch_loss = {split: sum_loss[split]/len(loaders[split]) for split in ["train", "val", "test"]}
            epoch_accuracy = {split: sum_accuracy[split]/len(loaders[split]) for split in ["train", "val", "test"]}
            # Update history
            for split in ["train", "val", "test"]:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])
    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Plot loss
        plt.title("Loss")
        for split in ["train", "test"]:
            plt.plot(history_loss[split], label=split)
        plt.legend()
        plt.show()
        # Plot accuracy
        plt.title("Accuracy")
        for split in ["train", "test"]:
            plt.plot(history_accuracy[split], label=split)
        plt.legend()
        plt.show()

But I still get the same error.

Solution

Yes, if you have 2 inputs to the network, you have to pass 2 arguments here:

pred = net(input1, input2)  # input1 ---> mnist, input2 ---> cifar
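One way to feed both datasets is to zip the two loaders, so each training step receives one MNIST batch and one CIFAR batch. A minimal sketch with dummy batch lists standing in for the real DataLoaders (TwoInputNet, the tensor sizes, and the batch lists are assumptions for illustration, not the asker's actual loaders):

```python
import torch
import torch.nn as nn

# Dummy stand-ins for loaders[split] and loaders_cifar[split]:
# each yields (input, labels) batches with batch size 8
mnist_batches = [(torch.randn(8, 4), torch.zeros(8)) for _ in range(3)]
cifar_batches = [(torch.randn(8, 4), torch.ones(8)) for _ in range(3)]

class TwoInputNet(nn.Module):
    """Shared-weight toy model with forward(self, m, c), like VGG16COMBO."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, m, c):
        return torch.cat((self.fc(m), self.fc(c)), dim=1)

net = TwoInputNet()
steps = 0
# zip pairs one MNIST batch with one CIFAR batch per step
# (and stops at the shorter loader)
for (m_in, m_lab), (c_in, c_lab) in zip(mnist_batches, cifar_batches):
    pred = net(m_in, c_in)  # both positional arguments supplied
    steps += 1

print(steps, pred.shape)  # 3 torch.Size([8, 2])
```

With real DataLoaders the same pattern applies inside the split loop: `for (m_in, m_lab), (c_in, c_lab) in zip(loaders[split], loaders_cifar[split]):`.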
