Pytorch Argmax形状问题

如何解决Pytorch Argmax形状问题

import sys,os
import cv2
import numpy as np
from tqdm import tqdm
REBUILD_DATA = True
import matplotlib.pyplot as plt
class ArtOrNot():
f = open('Art2','w')
f.write('oof')
f.close()
IMG_SIZE = 50
ART = (r'C:\Users\Kyel\Desktop\Python projects\ART')
NOTART = (r'C:\Users\Kyel\Desktop\Python projects\NOTART')
LABELS = {ART: 1,NOTART: 0}
training_data = []
artcount = 0
notartcount = 0


def make_training_data(self):
    for label in self.LABELS:
        print(label)
        for f in tqdm(os.listdir(label)):
            try:
                path = os.path.join(label,f)
                img = cv2.imread(path)
                img = cv2.resize(img,(self.IMG_SIZE,self.IMG_SIZE))
                #gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
                #(img) = (grey)
                self.training_data.append([np.array(img),np.eye(2)[self.LABELS[label]]])
                if label == self.ART:
                    self.artcount += 1
                    print(self.artcount)
                elif label == self.NOTART:
                    self.notartcount += 1
            except Exception as e:
                pass
    #print(training_data)
    np.random.shuffle(self.training_data)
    np.save("training_data.npy",self.training_data)
    print("Art:",self.artcount)
    print("NotART",self.notartcount)
    

if REBUILD_DATA:
    artornot = ArtOrNot()
    artornot.make_training_data()

training_data = np.load("training_data.npy",allow_pickle=True)
print(len(training_data))

import matplotlib.pyplot as plt
plt.imshow(training_data[1][0])
plt.show()
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1,32,5)
    self.conv2 = nn.Conv2d(32,64,5)
    self.conv3 = nn.Conv2d(64,128,5)
    
    
    x = torch.randn(50,50).view(-1,1,50,50)
    self._to_linear = None
    self.convs(x)
    
    
    
    self.fc1 = nn.Linear(self._to_linear,512)
    self.fc2 = nn.Linear(512,2)
    
    
def convs(self,x):
    x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
    x = F.max_pool2d(F.relu(self.conv2(x)),2))
    x = F.max_pool2d(F.relu(self.conv3(x)),2))
    print(x[0].shape)
    
    if self._to_linear is None:
        self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
        
    return x

def forward(self,x):
    x = self.convs(x)
    x = x.view(-1,self._to_linear)
    c = F.relu(self.fc1(x))
    x = self.fc2(x)
    #return F.softmax(x,dim =1)

net = Net() 
print(net)

import torch.optim as optim

optimizer = optim.Adam(net.parameters(),lr=0.001)
#loss_function = torch.nn.functional.mse_loss()


X = torch.Tensor([i[0] for i in training_data]).view(-1,50)
X = X/255.0
y = torch.Tensor([i[1] for i in training_data])

VAL_PCT = 0.2
val_size = int(len(X)*VAL_PCT)
print(val_size)

train_X = X[:-val_size]
#print(train_X)
train_y = X[:-val_size]
#print(train_y)
test_X = X[-val_size:]
print(test_X)
test_y = y[-val_size:]
print(test_y)
#print(X)
print(len(X))
#test_X = 8

#print(len(train_X),print(test_X))
#print(len(train_X))
#print(len(train_X))
#print(test_X)

BATCH_SIZE = 4
EPOCHS = 10

for epoch in range(EPOCHS):
    for i in tqdm(range(0,len(train_X),BATCH_SIZE)): # from 0,to the len of x,stepping BATCH_SIZE at a time. [:50] ..for now just to dev
        #print(f"{i}:{i+BATCH_SIZE}")
        batch_X = train_X[i:i+BATCH_SIZE].view(-1,50)
        batch_y = train_y[i:i+BATCH_SIZE]

        input = torch.randn(3,5,requires_grad=True)
    
        target = torch.randn(3,5)
        output = loss(input,target)
        output.backward()
        optimizer.step()
        print(output)
    #print(f"Epoch: {epoch}. Loss: {loss}")
#print(loss)
#print(len(train_X))
#print(len(test_X))
问题领域
correct = 0
total = 0

with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
    
        net_out = net(test_X[i].view(10,50))[0]  # returns a list,predicted_class = torch.argmax(net_out)

        if predicted_class == real_class:
            correct += 1
        total += 1
print("Accuracy: ",round(correct/total,3))
错误消息:RuntimeError:形状'[10,50,1,50]'对于大小为2500的输入无效

所以我知道从神经网络出来的数据大小是2500,当我更改参数时,又遇到另一个错误,说它期望4D张量,而我只给了3D张量?有帮助吗?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 <select id="xxx"> SELECT di.id, di.name, di.work_type, di.updated... <where> <if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 <property name="dynamic.classpath" value="tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-