GAN入门实践(一)--Tensorflow实现
最近在网上看到一个据说是 Alex Smola 写的关于生成对抗网络(Generative Adversarial Network, GAN)的入门教程,目的是从实践的角度讲解 GAN 的基本思想和实现过程。这个教程没有采用最普遍的用 GAN 生成图像作为例子,而是用生成一个特定的高斯分布作为切入点来讲解 GAN 的工作原理和效果。对 GAN 感兴趣的朋友不妨去看看这个简短的教程,直观易懂而不失有趣,链接如下:
https://github.com/zackchase/mxnet-the-straight-dope/blob/master/P10-C01-gan-intro.ipynb
这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。
Tensorflow实现的全代码在这里,另外还有conditional GAN的实现在这里。 其中还参考了DCGAN的实现。
Pytorch篇请看这里。
模型定义
class GAN(object):
def __init__(self):
# input, output
self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')
self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x')
# define the network
self.fake_x = self.netG(self.z)
self.real_logits = self.netD(self.x, reuse=False)
self.fake_logits = self.netD(self.fake_x, reuse=True)
......
def netD(self, x, reuse=False):
"""3-layer fully connected network"""
with tf.variable_scope("discriminator") as scope:
if reuse:
scope.reuse_variables()
W1 = tf.get_variable(name="d_W1", shape=[2, 5],
initializer=tf.contrib.layers.xavier_initializer(),
trainable=True)
......
Tensorflow 中对于变量(Variable)的创建有个需要特别主要的地方就是变量域(variable_scope)。当我们需要多次调用同一个网络(如上面的例子),或者多个网络共同分享其变量(比如权值共享)时,我们就可以把它们设成在同样的 variable_scope 下(当然变量名 name 也要是一样的),然后选择 reuse = True 实现变量的复用和共享。上面的代码就是一个例子,我们的 netD 只有一个,但是需要调用两次得到 real_logits 和 fake_logits。这时,我们第一次创建 netD 应该用 reuse=False,因为是创建一个新的网络和变量,而第二次调用时则选择 reuse=True,从而保证网络的变量是一样的,即用的同一个网络。
定义优化
class GAN(object):
def __init__(self):
......
# collect variables
t_vars = tf.trainable_variables()
self.d_vars = [var for var in t_vars if 'd_' in var.name]
self.g_vars = [var for var in t_vars if 'g_' in var.name]
......
gan = GAN()
d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars)
生成对抗网络的优化过程是一个生成器 G 和判别器 D 相互搏弈的过程,因此我们对 G 和 D 分别定义其对应的优化方法。它们分别对应不同的损失函数和不同的变量,其中我们运用变量命名的小技巧将生成器和判别器的变量分别存到不同的列表当中。
模型训练
这个部分就很常规了。注意以下几点就ok了:
- tensorflow中要先做初始化
- 定义好minibatch的生成方式
- 每次迭代要生成一个随机数向量 z_batch
跟教程中的实现类似,我们在每个epoch结束后都将判别器的分类正确率打印出来看看。由于 GAN 的训练中,loss并不能作为很好的监测训练过程的标准,打印分类accuracy可以让我们比较直观地看到模型的训练效果(我们最终希望生成的数据能够fool判别器,所以希望准确率为50%)。