Now, let's define some notation that will be used throughout the tutorial, starting with the discriminator. Let x be data representing an image. $D(x)$ is the discriminator network, which outputs the (scalar) probability that x came from the training data rather than from the generator. Here, since we are dealing with images, the input to $D(x)$ is an image of CHW size 3x64x64. Intuitively, $D(x)$ should be high when x comes from the training data and low when x comes from the generator. $D(x)$ can also be thought of as a traditional binary classifier.
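For concreteness, a DCGAN-style discriminator of this shape can be expressed as a TorchSharp module: a stack of strided convolutions with batch norm and LeakyReLU, ending in a Sigmoid that squashes the output to a probability. This is only an illustrative sketch; the class name, constructor arguments, and layer sizes are assumptions following the standard DCGAN layout, not necessarily the tutorial's exact model definition.

using static TorchSharp.torch;
using static TorchSharp.torch.nn;

// Illustrative sketch of a DCGAN-style discriminator for 3x64x64 images,
// assuming nc = 3 input channels and ndf = 64 base feature maps.
class Discriminator : Module<Tensor, Tensor>
{
    private readonly Module<Tensor, Tensor> main;

    public Discriminator(long nc = 3, long ndf = 64) : base(nameof(Discriminator))
    {
        main = Sequential(
            // nc x 64 x 64 -> ndf x 32 x 32
            Conv2d(nc, ndf, 4, stride: 2, padding: 1, bias: false),
            LeakyReLU(0.2, inplace: true),
            // ndf x 32 x 32 -> (ndf*2) x 16 x 16
            Conv2d(ndf, ndf * 2, 4, stride: 2, padding: 1, bias: false),
            BatchNorm2d(ndf * 2),
            LeakyReLU(0.2, inplace: true),
            // (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8
            Conv2d(ndf * 2, ndf * 4, 4, stride: 2, padding: 1, bias: false),
            BatchNorm2d(ndf * 4),
            LeakyReLU(0.2, inplace: true),
            // (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
            Conv2d(ndf * 4, ndf * 8, 4, stride: 2, padding: 1, bias: false),
            BatchNorm2d(ndf * 8),
            LeakyReLU(0.2, inplace: true),
            // (ndf*8) x 4 x 4 -> 1 x 1 x 1, squashed to a probability
            Conv2d(ndf * 8, 1, 4, stride: 1, padding: 0, bias: false),
            Sigmoid());
        RegisterComponents();
    }

    // Flatten to a vector of per-image probabilities so the output matches
    // the label shape used by the training loop later in this tutorial.
    public override Tensor forward(Tensor input) => main.forward(input).view(-1);
}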
For the generator's notation, let z be a latent space vector sampled from a standard normal distribution. $G(z)$ represents the generator function which maps the latent vector z to data space. The goal of $G$ is to estimate the distribution that the training data comes from ($p_{data}$) so that it can generate fake samples from that estimated distribution ($p_g$).
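The generator mirrors the discriminator with transposed convolutions, upsampling the nz-dimensional latent vector to a 3x64x64 image whose values are squashed to [-1, 1] by a Tanh. Again this is only a sketch of the standard DCGAN layout; the names and default sizes (nz = 100, ngf = 64, nc = 3) are assumptions matching the Options shown later.

using static TorchSharp.torch;
using static TorchSharp.torch.nn;

// Illustrative sketch of a DCGAN-style generator mapping z (nz x 1 x 1) to nc x 64 x 64.
class Generator : Module<Tensor, Tensor>
{
    private readonly Module<Tensor, Tensor> main;

    public Generator(long nz = 100, long ngf = 64, long nc = 3) : base(nameof(Generator))
    {
        main = Sequential(
            // z (nz x 1 x 1) -> (ngf*8) x 4 x 4
            ConvTranspose2d(nz, ngf * 8, 4, stride: 1, padding: 0, bias: false),
            BatchNorm2d(ngf * 8),
            ReLU(true),
            // (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8
            ConvTranspose2d(ngf * 8, ngf * 4, 4, stride: 2, padding: 1, bias: false),
            BatchNorm2d(ngf * 4),
            ReLU(true),
            // (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16
            ConvTranspose2d(ngf * 4, ngf * 2, 4, stride: 2, padding: 1, bias: false),
            BatchNorm2d(ngf * 2),
            ReLU(true),
            // (ngf*2) x 16 x 16 -> ngf x 32 x 32
            ConvTranspose2d(ngf * 2, ngf, 4, stride: 2, padding: 1, bias: false),
            BatchNorm2d(ngf),
            ReLU(true),
            // ngf x 32 x 32 -> nc x 64 x 64, values in [-1, 1]
            ConvTranspose2d(ngf, nc, 4, stride: 2, padding: 1, bias: false),
            Tanh());
        RegisterComponents();
    }

    public override Tensor forward(Tensor input) => main.forward(input);
}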
So, $D(G(z))$ is the (scalar) probability that the output of the generator $G$ is a real image. As described in Goodfellow's paper, $D$ and $G$ play a minimax game in which $D$ tries to maximize the probability that it correctly classifies reals and fakes ($\log D(x)$), and $G$ tries to minimize the probability that $D$ will predict its outputs are fake ($\log(1-D(G(z)))$). From the paper, the GAN loss function is

$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_z(z)}\big[\log(1-D(G(z)))\big]$$
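In code, this objective is realized with ordinary binary cross-entropy: the same loss is evaluated against a target of 1 for real images and 0 for fake ones. A minimal sketch of the pieces the training loop below expects; the names criterion, real_label and fake_label are taken from that loop, but how they are defined here is an assumption rather than code from the original.

using TorchSharp;
using static TorchSharp.torch;

// BCELoss computes -[y*log(p) + (1-y)*log(1-p)]; with y = 1 it reduces to -log(D(x)),
// with y = 0 to -log(1 - D(G(z))), i.e. the two terms of the minimax objective.
var criterion = nn.BCELoss();

// Label convention used throughout training: real = 1, fake = 0.
const float real_label = 1.0f;
const float fake_label = 0.0f;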
public class Options
{
    /// <summary>Root directory for dataset</summary>
    public string Dataroot { get; set; } = "data/celeba";

    /// <summary>Number of workers for dataloader</summary>
    public int Workers { get; set; } = 2;

    /// <summary>Batch size during training</summary>
    public int BatchSize { get; set; } = 128;

    /// <summary>Spatial size of training images. All images will be resized to this size using a transformer.</summary>
    public int ImageSize { get; set; } = 64;

    /// <summary>Number of channels in the training images. For color images this is 3</summary>
    public int Nc { get; set; } = 3;

    /// <summary>Size of z latent vector (i.e. size of generator input)</summary>
    public int Nz { get; set; } = 100;

    /// <summary>Size of feature maps in generator</summary>
    public int Ngf { get; set; } = 64;

    /// <summary>Size of feature maps in discriminator</summary>
    public int Ndf { get; set; } = 64;

    /// <summary>Number of training epochs</summary>
    public int NumEpochs { get; set; } = 5;

    /// <summary>Learning rate for optimizers</summary>
    public double Lr { get; set; } = 0.0002;

    /// <summary>Beta1 hyperparameter for Adam optimizers</summary>
    public double Beta1 { get; set; } = 0.5;

    /// <summary>Number of GPUs available. Use 0 for CPU mode.</summary>
    public int Ngpu { get; set; } = 1;
}
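The training loop further down also refers to options, defaultDevice, fixed_noise, optimizerD and optimizerG. One plausible way to wire these up, assuming netD and netG have already been constructed; the device-selection logic and the Adam arguments are assumptions based on the Options defaults above, not code from the original.

using System;
using TorchSharp;

// Configuration and device: fall back to CPU when no GPU is requested or available.
var options = new Options();
var defaultDevice = (options.Ngpu > 0 && torch.cuda.is_available()) ? torch.CUDA : torch.CPU;
Console.WriteLine($"Running on {defaultDevice}");

// Fixed batch of latent vectors, reused to visualize the generator's progress.
var fixed_noise = torch.randn(new long[] { 64, options.Nz, 1, 1 }, device: defaultDevice);

// One Adam optimizer per network, using the learning rate and beta1 from Options
// (positional arguments: lr, beta1, beta2).
var optimizerD = torch.optim.Adam(netD.parameters(), options.Lr, options.Beta1, 0.999);
var optimizerG = torch.optim.Adam(netG.parameters(), options.Lr, options.Beta1, 0.999);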
Dataset processing
In this tutorial, we will use the Celeb-A Faces dataset to train the model. It can be downloaded from the linked site or from Google Drive.
var img_list = new List<Tensor>();
var G_losses = new List<double>();
var D_losses = new List<double>();

Console.WriteLine("Starting Training Loop...");

Stopwatch stopwatch = new();
stopwatch.Start();

int i = 0;

// For each epoch
for (int epoch = 0; epoch < options.NumEpochs; epoch++)
{
    foreach (var item in dataloader)
    {
        var data = item[0];

        ////////////////////////////
        // (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ////////////////////////////
        // Train with all-real batch
        netD.zero_grad();
        // Format batch
        var real_cpu = data.to(defaultDevice);
        var b_size = real_cpu.size(0);
        var label = torch.full(new long[] { b_size }, real_label, dtype: ScalarType.Float32, device: defaultDevice);
        // Forward pass real batch through D
        var output = netD.call(real_cpu);
        // Calculate loss on all-real batch
        var errD_real = criterion.call(output, label);
        // Calculate gradients for D in backward pass
        errD_real.backward();
        var D_x = output.mean().item<float>();

        // Train with all-fake batch
        // Generate batch of latent vectors
        var noise = torch.randn(new long[] { b_size, options.Nz, 1, 1 }, device: defaultDevice);
        // Generate fake image batch with G
        var fake = netG.call(noise);
        label.fill_(fake_label);
        // Classify all fake batch with D
        output = netD.call(fake.detach());
        // Calculate D's loss on the all-fake batch
        var errD_fake = criterion.call(output, label);
        // Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward();
        var D_G_z1 = output.mean().item<float>();
        // Compute error of D as sum over the fake and the real batches
        var errD = errD_real + errD_fake;
        // Update D
        optimizerD.step();

        ////////////////////////////
        // (2) Update G network: maximize log(D(G(z)))
        ////////////////////////////
        netG.zero_grad();
        label.fill_(real_label); // fake labels are real for generator cost
        // Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD.call(fake);
        // Calculate G's loss based on this output
        var errG = criterion.call(output, label);
        // Calculate gradients for G
        errG.backward();
        var D_G_z2 = output.mean().item<float>();
        // Update G
        optimizerG.step();

        // ex: [0/25][4/3166] Loss_D: 0.5676 Loss_G: 7.5972 D(x): 0.9131 D(G(z)): 0.3024 / 0.0007
        Console.WriteLine($"[{epoch}/{options.NumEpochs}][{i % dataloader.Count}/{dataloader.Count}] Loss_D: {errD.item<float>():F4} Loss_G: {errG.item<float>():F4} D(x): {D_x:F4} D(G(z)): {D_G_z1:F4} / {D_G_z2:F4}");

        // Save sample images every 100 batches
        if (i % 100 == 0)
        {
            real_cpu.SaveJpeg("samples/real_samples.jpg");
            fake = netG.call(fixed_noise);
            fake.detach().SaveJpeg($"samples/fake_samples_epoch_{epoch:D3}.jpg");
        }

        i++;
    }

    netG.save($"samples/netg_{epoch}.dat");
    netD.save($"samples/netd_{epoch}.dat");
}
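After training, the generator weights saved at the end of each epoch can be reloaded to produce new samples. A small sketch, assuming the Generator class sketched earlier, the options and defaultDevice variables from above, and the SaveJpeg helper used in the loop; the file name simply matches the last epoch saved by the loop.

using TorchSharp;

// Reload the generator saved after the final epoch and sample a fresh batch of images.
using var trainedG = new Generator(options.Nz, options.Ngf, options.Nc);
trainedG.to(defaultDevice);
trainedG.load($"samples/netg_{options.NumEpochs - 1}.dat");
trainedG.eval();

using var noise = torch.randn(new long[] { 64, options.Nz, 1, 1 }, device: defaultDevice);
using (torch.no_grad())
{
    var generated = trainedG.call(noise);
    generated.SaveJpeg("samples/generated_after_training.jpg");
}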