RuntimeError: mat1形状和mat2形状不能相乘

如何解决RuntimeError: mat1形状和mat2形状不能相乘 ?

这一行

x = x.view( -1,x.size( 1 ))

意味着您保留第二维(通道)不变,而将其他所有内容放在第一维(批处理)。

由于self.encoder输出(1,48,4,4),这样做意味着您将获得(64,48),但从外观上看,我认为您需要(1,3072)

所以这应该可以解决这个特殊的问题。

x = x.view(x.size(0),-1)

然后你会遇到RuntimeError: unflatten: Provided sizes [48,4] don't multiply up to the size of dim 1 (3072) in the input tensor

原因是这里的未扁平化

nn.Linear(100,3072),nn.Unflatten(1,(48,4)),nn.ConvTranspose3d(48,32,2,1)

必须改为(48,4)

解决方法

我正在尝试输入一个形状为( 1,8,32,32,32 )的5D张量到我写的VAE:

self.encoder = nn.Sequential(
        nn.Conv3d( 8, 16, 4, 2, 1 ), # 32 -> 16
        nn.BatchNorm3d( 16 ), 
        nn.LeakyReLU( 0.2 ),
        
        nn.Conv3d( 16, 32, 4, 2, 1 ), # 16 -> 8
        nn.BatchNorm3d( 32 ),
        nn.LeakyReLU( 0.2 ),
        
        nn.Conv3d( 32, 48, 4, 2, 1 ), # 16 -> 4
        nn.BatchNorm3d( 48 ),
        nn.LeakyReLU( 0.2 ), 
    )
    
    self.fc_mu = nn.Linear( 3072, 100 ) # 48*4*4*4 = 3072
    self.fc_logvar = nn.Linear( 3072, 100 )
    
self.decoder = nn.Sequential(
    nn.Linear( 100, 3072 ),
    nn.Unflatten( 1, ( 48, 4, 4 )),
    nn.ConvTranspose3d( 48, 32, 4, 2, 1 ), # 4 -> 8
    nn.BatchNorm3d( 32 ),
    nn.Tanh(),
        
    nn.ConvTranspose3d( 32, 16, 4, 2, 1 ), # 8 -> 16
    nn.BatchNorm3d( 16 ),
    nn.Tanh(),
        
    nn.ConvTranspose3d( 16, 8, 4, 2, 1 ), # 16 -> 32
    nn.BatchNorm3d( 8 ),
    nn.Tanh(), 
)

def reparametrize( self, mu, logvar ):
    std = torch.exp( 0.5 * logvar )
    eps = torch.randn_like(  std )
    return mu + eps * std 

def encode( self, x ) :
    x = self.encoder( x )
    x = x.view( -1, x.size( 1 ))
    
    mu = self.fc_mu( x )
    logvar = self.fc_logvar( x )
    
    return self.reparametrize( mu, logvar ), mu, logvar 
    
def decode( self, x ):
    return self.decoder( x )
    
def forward( self, data ):
    z, mu, logvar = self.encode( data )
    return self.decode( z ), mu, logvar 

我得到的错误是:RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x48 and 3072x100)。我认为我已经正确地计算了每一层的输出尺寸,但我一定是弄错了,但我不确定在哪里。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

其他编程问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?