我是一个Keras机器学习的新手。我正在尝试理解生成对抗网络(GAN)。为此,我正在编写一个简单的示例。我使用以下函数生成数据:
def genReal(l): realX = [] for i in range(l): x = [] y = [] for i in np.arange(0.0, 1.0, 0.02): x.append(i + np.random.normal(0,0.01)) y.append(-abs(i-0.5)+0.5+ np.random.normal(0,0.01)) data = np.array(list(zip(x, y))) data = np.reshape(data, (100)) data.clip(0,1) realX.append(data) realX = np.array(realX) return realX
使用此函数生成的数据看起来与以下示例相似:
现在,目标是训练一个神经网络来生成类似数据。对于GAN,我们需要一个生成器网络,我这样建模的:
generator = Sequential()generator.add(Dense(128, input_shape=(100,), activation='relu'))generator.add(Dropout(rate=0.2))generator.add(Dense(128, activation='relu'))generator.add(Dropout(rate=0.2))generator.add(Dense(100, activation='sigmoid'))generator.compile(loss='mean_squared_error', optimizer='adam')
还有一个判别器,看起来像这样:
discriminator = Sequential()discriminator.add(Dense(128, input_shape=(100,), activation='relu'))discriminator.add(Dropout(rate=0.2))discriminator.add(Dense(128, activation='relu'))discriminator.add(Dropout(rate=0.2))discriminator.add(Dense(1, activation='sigmoid'))discriminator.compile(loss='mean_squared_error', optimizer='adam')
组合模型如下:
ganInput = Input(shape=(100,))x = generator(ganInput)ganOutput = discriminator(x)GAN = Model(inputs=ganInput, outputs=ganOutput)GAN.compile(loss='binary_crossentropy', optimizer='adam')
我有一个生成噪声的函数(一个随机数组)
def noise(l): noise = np.array([np.random.uniform(0, 1, size=[l, ])]) return noise
然后我训练模型:
for i in range(1000000): fake = generator.predict(noise(100)) print(i, "==>", discriminator.predict(fake)) discriminator.train_on_batch(genReal(1), np.array([1])) discriminator.train_on_batch(fake, np.array([0])) discriminator.trainable = False GAN.train_on_batch(noise(100), np.array([1])) discriminator.trainable = True
如你所见,我已经尝试了100万次迭代来训练模型。但生成器在不同输入下输出的数据看起来是这样的:
这显然不是我想要的结果。所以我的问题是:100万次迭代还不够,还是我的程序概念有问题?
编辑:
这是我用于绘制数据的函数:
def plotData(data): x = np.reshape(data,(50,2)) x = x.tolist() plt.scatter(list(zip(*x))[0],list(zip(*x))[1], c=col)
回答:
你的实现问题在于discriminator.trainable = False
在编译discriminator
之后没有任何效果。因此,当你执行GAN.train_on_batch
时,所有的权重(包括判别器和生成器网络的权重)都是可训练的。
解决这个问题的方法是在编译discriminator
之后,编译GAN
之前立即设置discriminator.trainable = False
:
discriminator.compile(loss='mean_squared_error', optimizer='adam') discriminator.trainable = FalseganInput = Input(shape=(100,))x = generator(ganInput)ganOutput = discriminator(x)GAN = Model(inputs=ganInput, outputs=ganOutput)GAN.compile(loss='binary_crossentropy', optimizer='adam')