如何在 DCGAN 中增加 image_size

如何解决如何在 DCGAN 中增加 image_size

我正在使用 DCGAN 来合成医学图像。不过目前Img_size是64,分辨率太低了。

如何更改生成器和鉴别器使分辨率达到 512*512?

下面是我的代码。

# Root directory for dataset
dataroot = "./Image/Knee/"

# Number of workers for dataloader
workers = 4

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 500

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 2

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = datasets.ImageFolder(root=dataroot,transform=transforms.Compose([
                               transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5)),]))
# Create the dataloader
dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(utils.make_grid(real_batch[0].to(device)[:64],padding=2,normalize=True).cpu(),(1,2,0)))

这是生成器代码。

# Generator Code

class Generator(nn.Module):
    def __init__(self,ngpu):
        super(Generator,self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z,going into a convolution
            nn.ConvTranspose2d( nz,ngf * 8,4,1,bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8,ngf * 4,nn.BatchNorm2d(ngf * 4),# state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4,ngf * 2,nn.BatchNorm2d(ngf * 2),# state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2,ngf,nn.BatchNorm2d(ngf),# state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf,nc,nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self,input):
        return self.main(input)
    
# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG,list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0,stdev=0.2.
netG.apply(weights_init)


print(device.type)
print(ngpu)
# Print the model
print(netG)

这里是鉴别器代码。

class Discriminator(nn.Module):
    def __init__(self,ngpu):
        super(Discriminator,self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc,ndf,nn.LeakyReLU(0.2,inplace=True),# state size. (ndf) x 32 x 32
            nn.Conv2d(ndf,ndf * 2,nn.BatchNorm2d(ndf * 2),# state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2,ndf * 4,nn.BatchNorm2d(ndf * 4),# state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4,ndf * 8,nn.BatchNorm2d(ndf * 8),# state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8,nn.Sigmoid()
        )

    def forward(self,input):
        return self.main(input)
    
# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD,list(range(ngpu)))
    
# Apply the weights_init function to randomly initialize all weights
#  to mean=0,stdev=0.2.
netD.apply(weights_init)

# Print the model
print(netD)

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64,nz,device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(),lr=lr,betas=(beta1,0.999))
optimizerG = optim.Adam(netG.parameters(),0.999))

我使用了基本的 DCGAN。我想改变网络结构生成512*512的高分辨率图片

这是损失函数和优化器

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64,0.999))

这是训练代码

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 500
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i,data in enumerate(dataloader,0):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,),real_label,device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batc
        output=output.float()
        label =label.float()
        #print(output.shape)
        #print(label.shape)
        errD_real = criterion(output,label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size,device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        print(fake.detach())
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output,label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D,perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output,label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        
        # Output training stats
        if i % 100 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch,num_epochs,i,len(dataloader),errD.item(),errG.item(),D_x,D_G_z1,D_G_z2))
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(utils.make_grid(fake,normalize=True))
            
        iters += 1

如果我使用你给我的答案代码,我会在下面收到错误消息

运行时错误:输入类型(torch.cuda.FloatTensor)和权重类型(torch.FloatTensor)应该相同

请再看一遍好吗?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 <select id="xxx"> SELECT di.id, di.name, di.work_type, di.updated... <where> <if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 <property name="dynamic.classpath" value="tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-