深度学习:GAN案例练习-minst手写数字

本文详细介绍了生成对抗网络(GAN)的理论与实践,包括目标、网络优化、训练技巧及代码实现。通过BP全连接网络和CNN结构,展示如何训练判别器和生成器,并探讨了损失函数、权重初始化和训练过程中的策略。在多个代码示例中,展示了生成手写数字的过程,以及训练过程中生成图片和损失函数的变化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

理论

参考:GAN原理详解

目标

最终期望两个网络达到一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近0.5(相当于随机猜测类别)。
目标:
对于生成器来说,传给辨别器的生成图片,生成器希望辨别器打上标签1。因为它要不断训练减小损失,以期望骗过判别器。
对于判别器来说,给定的真实图片,辨别器要为其打上标签1;给定的生成图片,辨别器要为其打上标签0;它要能够识别真假。

优化网络(定义损失)

GAN有两个网络,那么自然就有两个损失函数。
**生成网络的损失函数:**制造一个可以瞒过识别网络的输出
代表判断结果与1的距离。
识别网络的损失函数:
实数据就是真实数据,生成数据就是虚假数据(即真实数据与1的距离小,生成数据与0的距离小)

训练过程

GAN对抗网络的训练过程通常是两个网络单独且交替训练:先训练识别网络,再训练生成网络,再训练识别网络,如此反复,直到达到纳什均衡。

1.当生成器损失从很大的值迅速变为0,而判别器损失维持不变。
有可能时生成器生成能力较弱,因此一种可行的方法是增加生成器的层数来增加非线性。

2.某些文献采用生成器与判别器交叉训练的方法,即先训练判别器,再训练生成器,其目的是先训练判别器并更新其参数,先让其具有较好判别能力,而在训练生成器时因为判别器已具有一定判定能力,生成器的目的是尽可能骗过判别器,所以生成器会朝着生成更真实的图像前进;
也可以采用先训练生成器,再训练判别器,但是此种训练方法不推荐;同时也可以采用先更新生成器或判别器多次,再更新另一个一次的方法。

  1. 生成器损失、判别器损失,其中一个很大或者逐渐变大,另一个很小或者逐渐变小。
  2. 生成器和判别器的目的相反,也就是说两个生成器网络和判别器网络互为对抗,此消彼长。不可能Loss一直降到一个收敛的状态。

对于生成器,其Loss下降快,很有可能是判别器太弱,导致生成器很轻易的就"愚弄"了判别器。

对于判别器,其Loss下降快,意味着判别器很强,判别器很强则说明生成器生成的图像不够逼真,才使得判别器轻易判别,导致Loss下降很快。

也就是说,无论是判别器,还是生成器。loss的高低不能代表生成器的好坏。一个好的GAN网络,其GAN Loss往往是不断波动的。

技巧

训练GAN技巧
1.输入的图片经过处理,将0-255的值变为-1到1的值。
images = (images/255.0)*2 - 1

2 在generator输出层使用tanh激励函数,使得输出范围在 (-1,1)

3 保存生成的图片时,将矩阵值缩放到[0,1]之间
gen_image = (gen_image+1) / 2

4 使用leaky_relu激励函数,使得负值可以有一定的比重

5 使用BatchNormalization,使分布更均匀,最后一层不要使用。

6 在训练generator和discriminator的时候,一定要保证另外一个的参数是不变的,不能同时更新两个网络的参数。

7 如果训练很容易卡住,可以考虑使用WGAN
可以选择使用RMSprop optimizer

代码1(保存生成图片、loss可视化)

参考:GAN手写数字
代码位置:E:\项目例程\GNN\手写数字\3_可视化
评价:代码解释少,生产图片效果可以,
可学习保存生成图片代码
代码(加了可视化loss):

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import variable
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

G_in_dim = 100  # 模型的参数参考别人的网络设置
D_in_dim = 784
hidden1_dim = 256
hidden2_dim = 256
G_out_dim = 784
D_out_dim = 1

epoch = 50
batch_num = 60
lr_rate = 0.0003


def to_img(x):  # 这个函数参考自别人的网络,是将生成的假图像经过一系列操作能更清晰的显示出来,具体为什么这样设置没研究过
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


class G_Net(nn.Module):  # 生成网络,或者叫生成器,负责生成假数据
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(G_in_dim, hidden1_dim),
            nn.ReLU(),
            nn.Linear(hidden1_dim, hidden2_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden2_dim, G_out_dim),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.layer(x)
        return x


class D_Net(nn.Module):  # 判别网络,或者叫判别器,用来判别数据真假
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(D_in_dim, hidden1_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden1_dim, hidden2_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden2_dim, D_out_dim),
            nn.Sigmoid())

    def forward(self, x):
        x = self.layer(x)
        return x


data_tf = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize([0.5], [0.5])])
train_set = datasets.MNIST(root='data', train=True, transform=data_tf, download=True)
train_loader = DataLoader(train_set, batch_size=batch_num, shuffle=True)
g_net = G_Net()
d_net = D_Net()


G_losses = []
D_losses = []

criterion = nn.BCELoss()
G_optimizer = optim.Adam(g_net.parameters(), lr=lr_rate)
D_optimizer = optim.Adam(d_net.parameters(), lr=lr_rate)

iter_count = 0
for e in range(epoch):
    for data in train_loader:
        img, l = data
        img = img.view(img.size(0), -1)
        img = variable(img)
        r_label = variable(torch.ones(batch_num))
        f_label = variable(torch.zeros(batch_num))
        g_input = variable(torch.randn(batch_num, G_in_dim))

        r_output = d_net(img)
        r_loss = criterion(r_output.squeeze(-1), r_label)
        f_output = g_net(g_input)
        d_f_output = d_net(f_output)
        f_loss = criterion(d_f_output.squeeze(-1), f_label)
        sum_loss = r_loss + f_loss
        D_optimizer.zero_grad()
        sum_loss.backward()
        D_optimizer.step()

        g_input1 = variable(torch.randn
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值