Personal GAN Training Log

2025-05-30 06:49:59

1 GAN

GAN (Generative Adversarial Network) was once the mainstream generative architecture in deep learning. Although diffusion models have gradually risen to prominence in recent years, the idea behind GANs remains genuinely ingenious and distinctive. For a generative task, the goal is essentially to exploit a neural network's capacity for simulation and modeling to map a simple distribution onto a complex one, thereby meeting the demand for "creativity".

2 How GANs Work

A GAN consists of two neural networks: a generator G and a discriminator D. The generator takes random noise as input and produces fake data samples, while the discriminator receives both real data and the generated fakes and tries to tell them apart. The generator's goal is to produce ever more realistic fakes so that the discriminator cannot distinguish them from real data; the discriminator's goal is to separate real from fake as accurately as possible. This "adversarial" training gradually drives the generator and discriminator towards a balance and ultimately yields high-quality samples.

In my view, the essence of GANs lies in this clever adversarial pairing of generator G and discriminator D, which lets G gradually understand and approach the complex data distribution. On one hand, the discriminator greatly simplifies the problem of evaluating generated data, so that in principle a loss function can be designed "effortlessly" and the generator can be trained end to end without supervision (meaning unconditional GANs with no labels). On the other hand, if G+D is viewed as a single network, then a GAN is just the interleaved training of the two modules D and G: they trade blows, each pulling the other up by its own bootstraps, until a Nash equilibrium is finally reached.

In terms of the training procedure, it boils down to the following steps:

1. Freeze the gradients of the generator G and train the discriminator D on real images plus images produced by G, so that D learns to tell real images from fakes.
2. Freeze the gradients of the discriminator D and train the generator G on fake data paired with real labels, so that G's parameters are pushed in the direction of fooling the current D.
3. Repeat. D and G each exploit the window in which the other is not being trained to improve and beat the opponent, growing in alternation.

Seen this way, GAN is really an idea rather than a fixed network architecture: as long as gradients can flow, any two networks can be combined in this fashion to ultimately train the generator G.

Of course, the theory is beautiful, but in practice GAN training is extremely unstable! Because everything hinges on the balance between the two networks, once the discriminator becomes too strong, or the generator finds a blind spot in the discriminator, progress stalls. I have also always felt that discrimination is a far easier task than generation; indeed, during training the discriminator often improved too quickly, and I had to reset D to leave G room to keep improving. On the theory side, the WGAN authors also explained why GAN training is so hard: roughly, the distribution of images is extremely narrow, most of the high-dimensional space is noise rather than images, so the overlap between the generated distribution and the real one is tiny or nonexistent, the JS divergence becomes nearly impossible to optimize, and WGAN replaces it with the Earth Mover distance.
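For reference, this adversarial game is usually written as the following minimax objective (as in the original GAN paper), and WGAN swaps the implicit JS divergence for the Earth Mover (Wasserstein-1) distance:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

$$
W(p_{\text{data}}, p_g) = \inf_{\gamma \in \Pi(p_{\text{data}},\, p_g)} \mathbb{E}_{(x, y) \sim \gamma}\big[\lVert x - y \rVert\big]
$$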

3 DCGAN

How could I not play with GANs myself? I went straight to the official PyTorch DCGAN tutorial to get hands-on, using the Crypko dataset collected by Arvin Liu, which consists entirely of anime-style face avatars. Dataset preview:

Now straight to the code; it's basically copied from the PyTorch tutorial with the paths tweaked.

3.1 Dataset Split

```python
import os
import shutil
import random

def run():
    original_path = "../../../../Dataset/AnimeFaces"
    filename = os.listdir(original_path)
    filename.remove("train")
    filename.remove("test")
    test_list = list(random.sample(filename, 1500))
    train_list = list(filter(lambda x: x not in test_list, filename))
    # Split the dataset into train and test
    for i in range(len(train_list)):
        src = original_path + "/" + train_list[i]
        dst = original_path + "/trainfolder/train/" + train_list[i]
        shutil.move(src, dst)
    for i in range(len(test_list)):
        src = original_path + "/" + test_list[i]
        dst = original_path + "/test/" + test_list[i]
        shutil.move(src, dst)

# run()
```

3.2 Dataset

Here the training images are stored under …/trainfolder/train rather than directly under …/trainfolder, because ImageFolder expects one subdirectory per class ("train" simply acts as a single dummy class).

```python
import os
import torch
import torchvision.datasets as Dataset
import torchvision.transforms as transforms
import numpy as np

dataroot = "../../../../Dataset/AnimeFaces/trainfolder"
batch_size = 256

dataset = Dataset.ImageFolder(root=dataroot,
                              transform=transforms.Compose([
                                  transforms.Resize((64, 64)),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                              ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
```
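To double-check the data pipeline, here is a minimal sketch of my own (not part of the original code) that previews one batch from the dataloader with torchvision's make_grid; because of `num_workers=1`, on Windows it should run under the usual `if __name__ == '__main__':` guard:

```python
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutils

if __name__ == '__main__':
    # Grab one batch and undo the [-1, 1] normalization for display
    batch = next(iter(dataloader))
    grid = vutils.make_grid(batch[0][:64], padding=2, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Training images")
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    plt.show()
```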

3.3 Networks

Looking at it with today's eyes, DCGAN really is simple and crude.

```python
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim=100, g_feature=64):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, g_feature * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(g_feature * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(g_feature * 8, g_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(g_feature * 4, g_feature * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(g_feature * 2, g_feature, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature),
            nn.ReLU(True),
            nn.ConvTranspose2d(g_feature, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.net(input)

class Discriminator(nn.Module):
    def __init__(self, d_feature=64):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, d_feature, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature, d_feature * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature * 2, d_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature * 4, d_feature * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.net(input)
```
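A quick shape sanity check (my addition, a minimal sketch): the generator should turn a (N, 100, 1, 1) noise tensor into (N, 3, 64, 64) images through the stack of transposed convolutions, and the discriminator should reduce those back to one score per sample:

```python
import torch
from Net import Generator, Discriminator  # the two classes above, saved as Net.py

netG = Generator()
netD = Discriminator()
z = torch.randn(4, 100, 1, 1)   # a small batch of noise vectors
fake = netG(z)                  # expected shape: (4, 3, 64, 64)
score = netD(fake)              # expected shape: (4, 1, 1, 1)
print(fake.shape, score.shape)
```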

3.4 Training

Print some statistics along the way to keep an eye on the training process.

```python
import torch.nn as nn
import torch
from Dataset import dataloader
import torch.optim as optim
from Net import Generator, Discriminator
import numpy as np
import time

if __name__ == '__main__':
    device = torch.device("cuda")
    criterion = nn.BCELoss()
    fixed_noise = torch.randn(64, 100, 1, 1, device=device)
    real_label = 1.
    fake_label = 0.
    netG = Generator().to(device)
    netD = Discriminator().to(device)
    # Resume: G from the epoch-600 checkpoint, D rolled back to epoch 400
    netG.load_state_dict(torch.load("model_parm/G_epoch600.pt"))
    netD.load_state_dict(torch.load("model_parm/D_epoch400.pt"))
    D_lr = 2e-6
    G_lr = 2e-6
    optimizerD = optim.Adam(netD.parameters(), lr=D_lr, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=G_lr, betas=(0.5, 0.999))

    # Training Loop
    num_epochs = 1000
    t1 = time.time()
    print("Starting Training Loop...")
    # For each epoch (resuming from epoch 601)
    for epoch in range(601, num_epochs + 1):
        # Reset per epoch so the logged values are epoch averages
        G_losses = []
        D_losses = []
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Train the discriminator D: maximize log(D(x)) + log(1 - D(G(z))),
            #     i.e. real images -> 1, fake images -> 0
            ############################
            ## First train on real images
            netD.zero_grad()
            # Build the labels: all ones
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            # Noisy soft labels (optional)
            # label = np.random.rand(b_size)*0.8+0.15
            # label = torch.tensor(label,dtype=torch.float,device=device)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass through D
            output = netD(real_cpu).view(-1)
            # Compute the loss
            errD_real = criterion(output, label)
            # Backpropagate
            errD_real.backward()
            D_x = output.mean().item()

            ## Then train on fake images
            # Sample random noise
            noise = torch.randn(b_size, 100, 1, 1, device=device)
            # Generate fake images and labels
            fake = netG(noise)
            # Noisy soft labels (optional)
            # label = np.random.rand(b_size)*0.1+0.05
            # label = torch.tensor(label,dtype=torch.float,device=device)
            label.fill_(fake_label)
            # Forward pass through D
            output = netD(fake.detach()).view(-1)
            # Compute the loss
            errD_fake = criterion(output, label)
            # Backpropagate
            errD_fake.backward()
            # D_G_z1 measures how well G fools the current D *before* D is updated
            D_G_z1 = output.mean().item()
            # Total discriminator loss
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Train the generator G: maximize log(D(G(z))) in order to fool D
            ############################
            netG.zero_grad()
            # Pair fake images with real labels, so that after updating G its outputs
            # move towards being classified as real
            # Hard labels
            G_label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Train on the fake images
            output = netD(fake).view(-1)
            # Compute the loss
            errG = criterion(output, G_label)
            # Backpropagate
            errG.backward()
            # D_G_z2 measures how well G (before its update) fools the already-updated D
            # Normally D_G_z2 < D_G_z1
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Record the losses
            G_losses.append(errG.item())
            D_losses.append(errD.item())

        # save model
        if epoch % 50 == 0:
            torch.save(netG.state_dict(), 'model_parm/G_epoch' + str(epoch) + '.pt')
            torch.save(netD.state_dict(), 'model_parm/D_epoch' + str(epoch) + '.pt')
        # Output training stats
        if epoch % 20 == 0:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)
                  + ' cost time : ' + str(round(time.time() - t1, 4)) + 's')
            t1 = time.time()
        if epoch <= 605:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)
                  + ' cost time : ' + str(round(time.time() - t1, 4)) + 's')
            t1 = time.time()
        # Reset the discriminator D (disabled)
        # if epoch % 15 == 0:
        #     netD.load_state_dict(torch.load("model_parm/D_epoch0.pt"))

        # record train log
        Gloss = np.mean(G_losses)
        Dloss = np.mean(D_losses)
        with open('model_parm/train_log2.txt', 'a+') as f:
            string = str(epoch) + '\t' + str(round(Gloss, 5)) + '\t' + str(round(Dloss, 5)) + '\t' + \
                     str(round(D_x, 5)) + '\t' + str(round(D_G_z1, 5)) + '\t' + str(round(D_G_z2, 5)) + '\n'
            f.write(string)
```
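To look at the curves afterwards, a small plotting sketch is handy (my addition, assuming the log file only contains the tab-separated six-column rows written by the loop above):

```python
import numpy as np
import matplotlib.pyplot as plt

# Columns: epoch, G loss, D loss, D(x), D(G(z)) before D update, D(G(z)) after D update
log = np.loadtxt('model_parm/train_log2.txt')
plt.plot(log[:, 0], log[:, 1], label='G loss')
plt.plot(log[:, 0], log[:, 2], label='D loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
```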

3.5 Evaluation

Besides eyeballing the samples, the FID metric is of course the tool of choice.

```python
from Net import Generator
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutil
import os
import shutil
import re

def generate_img(epoch):
    device = torch.device("cuda")
    netG = Generator().to(device)
    netG.load_state_dict(torch.load("model_parm/G_epoch" + str(epoch) + ".pt"))
    fixed_noise = torch.randn(64, 100, 1, 1, device=device)
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    plt.imshow(np.transpose(vutil.make_grid(fake, padding=2, normalize=True), (1, 2, 0)))
    plt.show()

def cal_FID(epoch):
    # Generate images
    device = torch.device("cuda")
    netG = Generator().to(device)
    netG.load_state_dict(torch.load("model_parm/G_epoch" + str(epoch) + ".pt"))
    fixed_noise = torch.randn(1000, 100, 1, 1, device=device)
    fake = netG(fixed_noise).detach().cpu()
    for i in range(1000):
        vutil.save_image(fake[i], "D:/Pycharm/Dataset/AnimeFaces/fakeimg/" + str(i) + ".jpg", normalize=True)
    # Run pytorch_fid from the command line and parse its output
    os.system("activate pytorch")
    result = os.popen(r"python -m pytorch_fid D:\Pycharm\Dataset\AnimeFaces\test D:\Pycharm\Dataset\AnimeFaces\fakeimg")
    content = result.readlines()[0]
    fid = re.findall(r"\d+\.?\d*", content)
    fid = list(filter(lambda x: x != '0', fid))
    if len(fid) == 1:
        print("epoch", epoch, ":", float(fid[0]))
    else:
        print("epoch", epoch, ":", fid)
    shutil.rmtree(r"D:\Pycharm\Dataset\AnimeFaces\fakeimg")
    os.mkdir(r"D:\Pycharm\Dataset\AnimeFaces\fakeimg")

generate_img(400)
```
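As an aside, instead of shelling out via os.popen, the FID can also be computed in-process. A minimal sketch, under the assumption that the installed pytorch_fid version exposes calculate_fid_given_paths with this signature (recent releases do):

```python
import torch
# Assumption: this import path and signature exist in the installed pytorch_fid release
from pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fid = calculate_fid_given_paths(
    [r"D:\Pycharm\Dataset\AnimeFaces\test", r"D:\Pycharm\Dataset\AnimeFaces\fakeimg"],
    batch_size=50,
    device=device,
    dims=2048,
)
print("FID:", fid)
```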

4 WGAN

For the reasons to use WGAN, just refer to the paper. I had no intention of brute-forcing it with an MLP; I simply modified the network, the loss function and the training procedure on top of the DCGAN code.
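Concretely, the critic and generator losses that replace the BCE objective in the training code below, together with the weight clipping that keeps the critic roughly Lipschitz (as in the WGAN paper):

$$
L_D = \mathbb{E}_{\tilde{x} \sim p_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim p_{\text{data}}}\big[D(x)\big],
\qquad
L_G = -\,\mathbb{E}_{\tilde{x} \sim p_g}\big[D(\tilde{x})\big],
\qquad
w \leftarrow \operatorname{clip}(w, -c, c)
$$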

4.1 Net

All that changes is removing the sigmoid from the discriminator D.

```python
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, d_feature=64):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, d_feature, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature, d_feature * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature * 2, d_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature * 4, d_feature * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature * 8, 1, 4, 1, 0, bias=False),
            # The sigmoid is removed to avoid vanishing gradients
            # nn.Sigmoid()
        )

    def forward(self, input):
        return self.net(input)
```

4.2 Train

```python
from Dataset import dataloader
import torch.optim as optim
from Net import Generator, Discriminator
import time
import torch

if __name__ == '__main__':
    device = torch.device("cuda")
    # criterion = nn.BCELoss()  # not used: WGAN does not use BCE
    fixed_noise = torch.randn(64, 100, 1, 1, device=device)
    n_iter = 5
    clip_value = 0.01
    netG = Generator().to(device)
    netD = Discriminator().to(device)
    # netG.load_state_dict(torch.load("model_parm/G_epoch340.pt"))
    # netD.load_state_dict(torch.load("model_parm/D_epoch340.pt"))
    D_lr = 3e-5
    G_lr = 3e-5
    optimizerD = optim.RMSprop(netD.parameters(), lr=D_lr)
    optimizerG = optim.RMSprop(netG.parameters(), lr=G_lr)

    # Training Loop
    num_epochs = 1500
    t1 = time.time()
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs + 1):
        for i, data in enumerate(dataloader, 0):
            # Configure input
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizerD.zero_grad()
            # Sample noise as generator input
            noise = torch.randn(b_size, 100, 1, 1, device=device)
            # Generate a batch of images
            fake_imgs = netG(noise).detach()
            # Adversarial loss
            loss_D = -torch.mean(netD(real_cpu)) + torch.mean(netD(fake_imgs))
            loss_D.backward()
            optimizerD.step()
            # Clip weights of discriminator
            for p in netD.parameters():
                p.data.clamp_(-clip_value, clip_value)

            # (In the original WGAN recipe G is trained only every n_iter critic
            #  steps; here it is simply updated every iteration.)
            # -----------------
            #  Train Generator
            # -----------------
            optimizerG.zero_grad()
            # Generate a batch of images
            gen_imgs = netG(noise)
            # Adversarial loss
            loss_G = -torch.mean(netD(gen_imgs))
            loss_G.backward()
            optimizerG.step()

        # save model
        if epoch % 50 == 0:
            torch.save(netG.state_dict(), 'model_parm/G2_epoch' + str(epoch) + '.pt')
            torch.save(netD.state_dict(), 'model_parm/D2_epoch' + str(epoch) + '.pt')
        # Output training stats
        if epoch % 20 == 0:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t'
                  % (epoch, num_epochs, loss_D.item(), loss_G.item())
                  + ' cost time : ' + str(round(time.time() - t1, 4)) + 's')
            t1 = time.time()
        # Reset the discriminator D (disabled)
        # if epoch % 15 == 0:
        #     netD.load_state_dict(torch.load("model_parm/D_epoch0.pt"))

        # record train log
        with open('model_parm/train_log2.txt', 'a+') as f:
            string = str(epoch) + '\t' + str(round(loss_G.item(), 5)) + '\t' + str(round(loss_D.item(), 5)) + '\n'
            f.write(string)
```

5 Results

I trained about four runs without much method. In the end the vanilla DCGAN still felt the best; things like soft labels and WGAN didn't really bring any improvement. But admittedly the training was very casual: each run used a different number of epochs, and I barely recorded the learning-rate changes, the criteria for rolling back the discriminator and generator, and so on. It was purely a hands-on exercise to get a feel for GANs.

5.1 Vanilla DCGAN

Trained for 1000 epochs in total. Around epoch 600 the discriminator broke down, so I rolled it back to the epoch-400 checkpoint and continued training up to 1000 epochs. Straight to the best result (epoch 400):

| epoch | FID↓ |
| --- | --- |
| 370 | 84.92 |
| 400 | 79.94 |
| 500 | 92.37 |
| 600 | 93.45 |
| …… | …… |
| 900 | 88.99 |
| 1000 | 89.66 |

5.2 DCGAN + Soft Labels + Intermittent Discriminator Updates

Trained for 350 epochs; the best result (epoch 350) is as follows:

| epoch | FID↓ |
| --- | --- |
| 0 | 345.39 |
| 150 | 271.35 |
| 350 | 164.42 |

5.3 WGAN + Intermittent Generator Updates

I have to admit that although WGAN's final quality was only so-so, the loss went down steadily for most of training and there was hardly any mode collapse. Trained for 1100 epochs; the best result (epoch 1100) is as follows:

| epoch | FID↓ |
| --- | --- |
| 0 | 386.23 |
| 100 | 284.69 |
| 200 | 218.54 |
| 300 | 185.20 |
| …… | …… |
| 900 | 154.90 |
| 1000 | 154.88 |
| 1100 | 148.92 |

5.4 WGAN

Trained for 1500 epochs; the best result (epoch 1400) is as follows:

| epoch | FID↓ |
| --- | --- |
| 0 | 306.68 |
| 100 | 142.88 |
| 200 | 125.99 |
| 300 | 124.07 |
| …… | …… |
| 1300 | 113.03 |
| 1400 | 111.63 |
| 1500 | 112.65 |

6 Summary

All in all, GANs are a lot of fun, and they barely eat VRAM: with a batch size of 256 I used less than 2 GB. By comparison, merely fine-tuning a LoRA for SD at batchsize=1 already takes 8 GB of VRAM; sure enough, only scale matters. Also, inserting images in Markdown is a terrible experience.